{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "WARNING:root:Cuda kernels could not loaded -> no CUDA support!\n",
      "2024-04-02 11:54:03,526\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
      "WARNING:evotorch:The logger is already configured. The default configuration will not be applied. Call `set_default_logger_config` with `override=True` to override the current configuration.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import yaml\n",
    "from ml_collections import ConfigDict\n",
    "from tqdm import tqdm\n",
    "from copy import deepcopy\n",
    "\n",
    "import torch\n",
    "import torch_geometric\n",
    "\n",
    "from utils.data import load_dataset, make_dataset_splits, load_dataset_splits\n",
    "from utils.split import SplitManager, node_induced_subgraph\n",
    "from utils.storage import TensorHash\n",
    "from utils.model import load_model_class, accuracy, load_model_instance, create_model_instance\n",
    "from utils.attack import load_attack_class, attack_storage_label, create_attack_instance\n",
    "\n",
    "from robust_diffusion.data import SparseGraph\n",
    "from robust_diffusion.data import count_edges_for_idx\n",
    "from robust_diffusion.helper import utils as robust_utils\n",
    "from robust_diffusion.train import train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "## Experiment configs\n",
    "\n",
    "# dataset / model under attack\n",
    "dataset_name = \"cora_ml\"\n",
    "model_name = \"GCN\"\n",
    "model_params = None  # replaced by the loaded model's parameters further down\n",
    "n_splits = 50  # this many stored splits must already exist\n",
    "\n",
    "# split overrides (not read by the cells below)\n",
    "training_split = validation_split = None\n",
    "training_split_type = validation_split_type = None\n",
    "\n",
    "# attack under test\n",
    "attack_name = \"EvaAttack\"\n",
    "attack_params = None\n",
    "epsilon = 0.1  # NOTE(review): presumably the perturbation budget; not read below - confirm\n",
    "\n",
    "inductive = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Experiment Started\n",
      "Experiment Started\n",
      "Loading dataset = cora_ml\n",
      "Found 50 splits!\n",
      "Loading pretrained GCN model on cora_ml dataset for 50 splits\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/torch_geometric/data/in_memory_dataset.py:157: UserWarning: It is not recommended to directly access the internal storage format `data` of an 'InMemoryDataset'. If you are absolutely certain what you are doing, access the internal storage via `InMemoryDataset._data` instead to suppress this warning. Alternatively, you can access stacked individual attributes of every graph via `dataset.{attr_name}`.\n",
      "  warnings.warn(msg)\n"
     ]
    }
   ],
   "source": [
    "print(\"Experiment Started\")\n",
    "\n",
    "\n",
    "def _load_yaml(path):\n",
    "    \"\"\"Parse a YAML config file and close the file handle afterwards.\"\"\"\n",
    "    with open(path) as fh:\n",
    "        return yaml.safe_load(fh)\n",
    "\n",
    "\n",
    "## Loading general configs (like dataset_root, etc.) and initial parameters\n",
    "general_config = _load_yaml(\"conf/general-config.yaml\")\n",
    "default_dataset_configs = _load_yaml(\"conf/data-configs.yaml\").get(\"configs\").get(\"default\")\n",
    "default_model_configs = _load_yaml(\"conf/model-configs.yaml\").get(\"configs\")\n",
    "default_attack_configs = _load_yaml(\"conf/attack-configs.yaml\").get(\"configs\")\n",
    "\n",
    "# extracting configs (fallback paths are used when a key is missing)\n",
    "dataset_root = general_config.get(\"dataset_root\", \"data/\")\n",
    "splits_root = general_config.get(\"splits_root\", \"splits/\")\n",
    "models_root = general_config.get(\"models_root\", \"models/\")\n",
    "results_root = general_config.get(\"results_root\", \"results/\")\n",
    "reports_root = general_config.get(\"reports_root\", \"reports/\")\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "# Loads splits and a model produced earlier by the training scripts; nothing is trained here.\n",
    "print(\"Loading dataset =\", dataset_name)\n",
    "\n",
    "# Split files are named \"<dataset>-<split_code>.pt\". Sorted because os.listdir order is\n",
    "# arbitrary and dataset_splits[0] below must pick the same split on every run.\n",
    "dataset_splits = sorted(\n",
    "    split_record for split_record in os.listdir(splits_root)\n",
    "    if split_record.split(\"-\")[0] == dataset_name\n",
    ")\n",
    "creating_splits = max(n_splits - len(dataset_splits), 0)\n",
    "\n",
    "# splits are never created here; the emptiness check also guards dataset_splits[0] below\n",
    "if creating_splits > 0 or not dataset_splits:\n",
    "    raise ValueError(\"Not enough splits for the dataset. Create the splits by running training scripts.\")\n",
    "\n",
    "print(f\"Found {len(dataset_splits)} splits!\")\n",
    "\n",
    "print(f\"Loading pretrained {model_name} model on {dataset_name} dataset for {n_splits} splits\")\n",
    "\n",
    "clean_accs = []\n",
    "pert_accs = []\n",
    "# NOTE: only the first split is used for now; the per-split loop\n",
    "# `for split_file in tqdm(dataset_splits[:n_splits]):` is disabled while prototyping.\n",
    "split_file = dataset_splits[0]\n",
    "split_code = split_file.split(\"-\")[1].replace(\".pt\", \"\")\n",
    "\n",
    "data = load_dataset_splits(\n",
    "    dataset_name, split_code, inductive=inductive,\n",
    "    dataset_root=dataset_root, splits_root=splits_root, device=device)\n",
    "\n",
    "training_attr = data[\"training_attr\"]\n",
    "training_adj = data[\"training_adj\"]\n",
    "labels = data[\"labels\"]\n",
    "training_idx = data[\"training_idx\"]\n",
    "validation_idx = data[\"validation_idx\"]\n",
    "test_attr = data[\"test_attr\"]\n",
    "test_adj = data[\"test_adj\"]\n",
    "test_mask = data[\"test_mask\"]\n",
    "test_idx = test_mask.nonzero(as_tuple=True)[0]  # positions where test_mask is set\n",
    "dataset_info = data[\"dataset_info\"]\n",
    "split_name = data[\"split_name\"]\n",
    "\n",
    "try:\n",
    "    model_instance = load_model_instance(\n",
    "        model_name=model_name, model_params=model_params,\n",
    "        test_attr=test_attr, test_adj=test_adj, labels=labels,\n",
    "        test_mask=test_mask, split_name=split_name, dataset_info=dataset_info,\n",
    "        inductive=inductive,\n",
    "        models_root=models_root,\n",
    "        default_model_configs=default_model_configs, device=device)\n",
    "except FileNotFoundError as e:\n",
    "    print(e)\n",
    "    # chain the original error so the underlying FileNotFoundError stays in the traceback\n",
    "    raise ValueError(\"Model not found. Run training scripts to train the model.\") from e\n",
    "\n",
    "model = model_instance[\"model\"]\n",
    "acc = model_instance[\"accuracy\"]\n",
    "model_params = model_instance[\"model_params\"]  # overrides the None from the config cell\n",
    "model_storage_name = model_instance[\"model_storage_name\"]\n",
    "clean_accs.append(acc)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "from eva_attack import EvaAttack\n",
    "from robust_diffusion.attacks.base_attack import SparseAttack\n",
    "from evotorch.operators.base import CopyingOperator\n",
    "from evotorch.core import Problem, SolutionBatch\n",
    "from evotorch import Problem\n",
    "from evotorch.algorithms import GeneticAlgorithm\n",
    "from evotorch import operators as evo_ops\n",
    "from evotorch.logging import StdOutLogger\n",
    "from copy import deepcopy\n",
    "\n",
    "\n",
    "class PositiveIntMutation(CopyingOperator):\n",
    "    \"\"\"Mutation operator for integer genomes in which negative genes act as an 'off' marker.\n",
    "\n",
    "    Every gene is picked independently with probability ``mutation_rate``. For a picked\n",
    "    gene a second coin with probability ``toggle_rate`` decides its fate:\n",
    "\n",
    "        gene >= 0, toggle     -> -1 (switched off)\n",
    "        gene >= 0, no toggle  -> fresh value from torch.randint(0, problem.upper_bounds)\n",
    "        gene <  0, toggle     -> fresh value (switched on)\n",
    "        gene <  0, no toggle  -> -1 (stays off)\n",
    "\n",
    "    Genes that are not picked are left unchanged.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, problem, mutation_rate=0.1, toggle_rate=0.5):\n",
    "        super().__init__(problem)\n",
    "        self.mutation_rate = mutation_rate\n",
    "        self.toggle_rate = toggle_rate\n",
    "\n",
    "    @torch.no_grad()\n",
    "    def _do(self, batch: SolutionBatch) -> SolutionBatch:\n",
    "        # work on a deep copy so the incoming batch is left untouched\n",
    "        mutated = deepcopy(batch)\n",
    "        genes = mutated.access_values()\n",
    "\n",
    "        picked = torch.rand(size=genes.shape, device=genes.device) < self.mutation_rate\n",
    "        old_genes = genes[picked]\n",
    "        toggle = torch.rand(size=old_genes.shape, device=old_genes.device) < self.toggle_rate\n",
    "        # NOTE(review): torch.randint expects an int `high`; confirm problem.upper_bounds is scalar-like\n",
    "        fresh = torch.randint(0, self.problem.upper_bounds, size=old_genes.shape, device=genes.device)\n",
    "        # a picked gene becomes -1 exactly when \"is active\" and \"toggle\" agree (see docstring table)\n",
    "        fresh[(old_genes >= 0) == toggle] = -1\n",
    "        genes[picked] = fresh\n",
    "        return mutated"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "# EvaAttack over the test nodes; graph data and model live on the same `device`\n",
    "attack_idx = torch.nonzero(test_mask, as_tuple=True)[0]\n",
    "self_ = EvaAttack(\n",
    "    attr=test_attr,\n",
    "    adj=test_adj,\n",
    "    labels=labels,\n",
    "    model=model,\n",
    "    idx_attack=attack_idx.cpu().numpy(),\n",
    "    device=device,\n",
    "    data_device=device,\n",
    "    make_undirected=True,\n",
    "    binary_attr=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_ga(x):\n",
    "    \"\"\"Objective for a single candidate: score perturbation `x` against the attacked model.\n",
    "\n",
    "    Thin wrapper around self_._evaluate_sparse_perturbation that supplies the attack's\n",
    "    own graph, labels, attack mask, model and device.\n",
    "    \"\"\"\n",
    "    return self_._evaluate_sparse_perturbation(\n",
    "        attr=self_.attr, adj=self_.adj, labels=self_.labels,\n",
    "        mask_attack=self_.mask_attack, perturbation=x, model=self_.model, device=self_.device)\n",
    "\n",
    "\n",
    "n_genes = 100  # genome length; presumably the max number of edge flips per solution - confirm\n",
    "# upper bound = number of unordered node pairs, n*(n-1)/2\n",
    "n_node_pairs = (self_.n_nodes * (self_.n_nodes - 1)) // 2\n",
    "\n",
    "problem = Problem(\n",
    "    \"min\",\n",
    "    objective_func=evaluate_ga,\n",
    "    solution_length=n_genes,\n",
    "    dtype=torch.int64,\n",
    "    bounds=(0, n_node_pairs),\n",
    "    device=self_.device,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# GA operators: tournament-selected multi-point crossover, then the custom mutation\n",
    "ga_operators = [\n",
    "    evo_ops.MultiPointCrossOver(\n",
    "        problem,\n",
    "        tournament_size=self_.tournament_size,\n",
    "        num_points=self_.num_cross_over,\n",
    "    ),\n",
    "    PositiveIntMutation(\n",
    "        problem,\n",
    "        mutation_rate=self_.mutation_rate,\n",
    "        toggle_rate=self_.mutation_toggle_rate,\n",
    "    ),\n",
    "]\n",
    "searcher = GeneticAlgorithm(\n",
    "    problem,\n",
    "    operators=ga_operators,\n",
    "    popsize=self_.num_population,\n",
    "    re_evaluate=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# short aliases for the attack's graph / model state used by the scratch cells below\n",
    "attr, adj, labels, mask_attack, model, device = (\n",
    "    self_.attr, self_.adj, self_.labels, self_.mask_attack, self_.model, self_.device,\n",
    ")\n",
    "capacity = 3  # perturbed graphs per bulk_evaluation call in the chunked loop"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# run a single GA iteration (candidates are scored through evaluate_ga)\n",
    "searcher.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# population after the step; used directly by the scratch evaluation cells below\n",
    "solutions = searcher.population"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from evotorch.core import SolutionBatch\n",
    "\n",
    "\n",
    "class GraphEvalProblem(Problem):\n",
    "    \"\"\"evotorch Problem that scores a SolutionBatch through chunked bulk graph evaluation.\n",
    "\n",
    "    Every solution is turned into a perturbed graph with\n",
    "    ``attack_class._create_perturbed_graph`` and the graphs are scored ``capacity`` at a\n",
    "    time with ``attack_class.bulk_evaluation``.\n",
    "\n",
    "    Keyword-only extras (kept here, NOT forwarded to ``Problem.__init__``):\n",
    "        capacity: max number of perturbed graphs per bulk_evaluation call (default 1).\n",
    "        attack_class: attack object (e.g. the EvaAttack instance ``self_``) that provides\n",
    "            attr, adj, labels, model, mask_attack, device, _create_perturbed_graph and\n",
    "            bulk_evaluation.\n",
    "    All remaining arguments are passed to ``Problem.__init__`` unchanged.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if capacity < 1.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, *args, capacity=1, attack_class=None, **kwargs):\n",
    "        # capacity / attack_class are captured as keyword-only parameters so they never\n",
    "        # reach Problem.__init__ (its documented signature has no such parameters).\n",
    "        super().__init__(*args, **kwargs)\n",
    "        if capacity < 1:\n",
    "            raise ValueError(f\"capacity must be >= 1, got {capacity}\")\n",
    "        self.capacity = capacity\n",
    "        self.attack_class = attack_class\n",
    "\n",
    "    def _evaluate_batch(self, batch: SolutionBatch):\n",
    "        \"\"\"Score every solution in ``batch`` and store the results with ``set_evals``.\"\"\"\n",
    "        attack = self.attack_class\n",
    "        n_solutions = len(batch)\n",
    "        if n_solutions == 0:\n",
    "            return  # nothing to score (torch.cat would fail on an empty list)\n",
    "        chunk_evals = []\n",
    "        # range(0, n, capacity) visits every solution, including a short final chunk\n",
    "        for start in range(0, n_solutions, self.capacity):\n",
    "            chunk = batch[start:min(start + self.capacity, n_solutions)]\n",
    "            graphs = [\n",
    "                attack._create_perturbed_graph(\n",
    "                    attr=attack.attr, adj=attack.adj,\n",
    "                    perturbation=solution.values, device=attack.device)\n",
    "                for solution in chunk]\n",
    "            attr_list = [graph[0] for graph in graphs]\n",
    "            adj_list = [graph[1] for graph in graphs]\n",
    "            chunk_evals.append(attack.bulk_evaluation(\n",
    "                attr_list=attr_list, adj_list=adj_list,\n",
    "                labels_list=[attack.labels] * len(attr_list),\n",
    "                model=attack.model, mask_attack=attack.mask_attack))\n",
    "        # write into the batch we were handed, not the global `solutions`\n",
    "        batch.set_evals(torch.cat(chunk_evals))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Scratch run of the chunked evaluation on `solutions`: build `capacity` perturbed\n",
    "# graphs at a time and score each chunk with one bulk_evaluation call.\n",
    "evaluation_result = []\n",
    "start_ind = 0\n",
    "# compare against len(solutions) (not len - 1) so a short final chunk is still scored\n",
    "while start_ind < len(solutions):\n",
    "    end_ind = min(start_ind + capacity, len(solutions))\n",
    "    batch = solutions[start_ind:end_ind]\n",
    "    outputs = [\n",
    "        self_._create_perturbed_graph(attr=self_.attr, adj=self_.adj, perturbation=perturbation.values, device=self_.device)\n",
    "        for perturbation in batch]\n",
    "    attr_instances = [output[0] for output in outputs]\n",
    "    adj_instances = [output[1] for output in outputs]\n",
    "    bulk_eval = self_.bulk_evaluation(\n",
    "        attr_list=attr_instances, adj_list=adj_instances,\n",
    "        labels_list=[self_.labels] * len(attr_instances),\n",
    "        model=self_.model, mask_attack=self_.mask_attack)\n",
    "    evaluation_result.append(bulk_eval)\n",
    "    start_ind = end_ind\n",
    "evaluation_result = torch.cat(evaluation_result)\n",
    "solutions.set_evals(evaluation_result)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.7993, 0.7993, 0.7993, 0.7996, 0.7996, 0.7996, 0.8000, 0.8000, 0.8000,\n",
       "        0.8000, 0.8000, 0.8004, 0.8004, 0.8004, 0.8004, 0.8004, 0.8007, 0.8007,\n",
       "        0.8007, 0.8007, 0.8007, 0.8007, 0.8007, 0.8007, 0.8011, 0.8011, 0.8011,\n",
       "        0.8011, 0.8011, 0.8011, 0.8011, 0.8011, 0.8011, 0.8011, 0.8011, 0.8011,\n",
       "        0.8011, 0.8011, 0.8011, 0.8015, 0.8015, 0.8015, 0.8015, 0.8015, 0.8015,\n",
       "        0.8015, 0.8015, 0.8015, 0.8015, 0.8015, 0.8015, 0.8015, 0.8015, 0.8015,\n",
       "        0.8015, 0.8015, 0.8015, 0.8015, 0.8015, 0.8015, 0.8015, 0.8015, 0.8015,\n",
       "        0.8018, 0.8018, 0.8018, 0.8018, 0.8018, 0.8018, 0.8018, 0.8018, 0.8018,\n",
       "        0.8018, 0.8018, 0.8018, 0.8018, 0.8018, 0.8018, 0.8018, 0.8018, 0.8018,\n",
       "        0.8018, 0.8018, 0.8018, 0.8018, 0.8018, 0.8018, 0.8018, 0.8018, 0.8018,\n",
       "        0.8018, 0.8018, 0.8018, 0.8018, 0.8018, 0.8018, 0.8022, 0.8022, 0.8022,\n",
       "        0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8022,\n",
       "        0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8022,\n",
       "        0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8022,\n",
       "        0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8022,\n",
       "        0.8022, 0.8022, 0.8022, 0.8022, 0.8022, 0.8026, 0.8026, 0.8026, 0.8026,\n",
       "        0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026,\n",
       "        0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026,\n",
       "        0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026,\n",
       "        0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026,\n",
       "        0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8026,\n",
       "        0.8026, 0.8026, 0.8026, 0.8026, 0.8026, 0.8029, 0.8029, 0.8029, 0.8029,\n",
       "        0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029,\n",
       "        0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029,\n",
       "        0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029,\n",
       "        0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029,\n",
       "        0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029,\n",
       "        0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029,\n",
       "        0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029,\n",
       "        0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029, 0.8029,\n",
       "        0.8029, 0.8029, 0.8029, 0.8029, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033,\n",
       "        0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033,\n",
       "        0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033,\n",
       "        0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033,\n",
       "        0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033,\n",
       "        0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033,\n",
       "        0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033,\n",
       "        0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033,\n",
       "        0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033,\n",
       "        0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033, 0.8033,\n",
       "        0.8033, 0.8033, 0.8033, 0.8033, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037,\n",
       "        0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037,\n",
       "        0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037,\n",
       "        0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037,\n",
       "        0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037,\n",
       "        0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037,\n",
       "        0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037,\n",
       "        0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037,\n",
       "        0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037,\n",
       "        0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037,\n",
       "        0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037,\n",
       "        0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8037, 0.8041,\n",
       "        0.8041, 0.8041, 0.8041, 0.8041, 0.8041, 0.8041, 0.8041, 0.8041, 0.8041,\n",
       "        0.8041, 0.8041, 0.8041, 0.8041, 0.8041, 0.8041, 0.8041, 0.8041, 0.8041,\n",
       "        0.8041, 0.8041, 0.8041, 0.8041, 0.8041, 0.8041, 0.8041, 0.8041, 0.8041,\n",
       "        0.8041, 0.8041, 0.8041, 0.8041, 0.8041], device='cuda:0')"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.7993, 0.7993, 0.7993], device='cuda:0')"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# inspect the most recent bulk_evaluation result\n",
    "bulk_eval"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [],
   "source": [
    "# `bulk_evaluation` is not defined as a free function anywhere in this notebook (NameError on a\n",
    "# fresh kernel); call the EvaAttack method with the same arguments, as the evaluation loop does.\n",
    "# NOTE(review): the next cell indexes a second axis of the result - confirm the method's output shape.\n",
    "bulk_eval = self_.bulk_evaluation(attr_list=attr_instances, adj_list=adj_instances,\n",
    "                                  labels_list=[self_.labels] * len(attr_instances),\n",
    "                                  model=self_.model, mask_attack=self_.mask_attack)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 45,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 2715])"
      ]
     },
     "execution_count": 45,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# selects the attacked columns along dim 1, so this expects a 2-D (graph x column) result;\n",
    "# NOTE(review): a 1-D per-graph score vector would raise an IndexError here - confirm\n",
    "bulk_eval[:, mask_attack.nonzero(as_tuple=True)[0]].shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# number of perturbed graphs built in the last chunk of the evaluation loop\n",
    "len(outputs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([500, 100])"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# expected (popsize, solution_length)\n",
    "solutions.values.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Perturbed graphs for three hand-picked solutions, built in one loop instead of three\n",
    "# copy-pasted stanzas. NOTE(review): index 0 is skipped - intentional?\n",
    "sample_perturbations = [solutions.values[i] for i in (1, 2, 3)]\n",
    "sample_graphs = [\n",
    "    self_._create_perturbed_graph(attr=attr, adj=adj, perturbation=perturbation, device=device)\n",
    "    for perturbation in sample_perturbations]\n",
    "attr_instances = [graph[0] for graph in sample_graphs]\n",
    "adj_instances = [graph[1] for graph in sample_graphs]\n",
    "\n",
    "# numbered aliases, still used by the cells below\n",
    "perturbation1, perturbation2, perturbation3 = sample_perturbations\n",
    "attr_instance1, attr_instance2, attr_instance3 = attr_instances\n",
    "adj_instance1, adj_instance2, adj_instance3 = adj_instances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([0.8110, 0.8117, 0.8117], device='cuda:0')"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# inputs for a manual bulk_evaluation call (no top-level indent, so this also runs as plain Python)\n",
    "attr_list = attr_instances\n",
    "adj_list = adj_instances\n",
    "labels_list = [labels] * len(attr_list)  # one label tensor per graph in attr_list\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[   0,    1,    2,  ...,  997,  998,  999],\n",
       "        [1000, 1001, 1002,  ..., 1997, 1998, 1999],\n",
       "        [2000, 2001, 2002,  ..., 2997, 2998, 2999]])"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# scratch: how .view(3, -1) splits a flat range into 3 equal rows\n",
    "torch.arange(0, 3000).view(3, -1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(torch.Size([8985, 2879]), torch.Size([2995, 2879]))"
      ]
     },
     "execution_count": 59,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# stack the sampled attribute matrices row-wise (torch.cat's default dim is 0)\n",
    "attr_group = torch.cat(attr_instances)\n",
    "(attr_group.shape, attr_instance1.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [],
   "source": [
    "def cat_block_sparse(*matrices, block_size):\n",
    "    \"\"\"Place square sparse COO matrices on the diagonal of one larger sparse matrix.\n",
    "\n",
    "    Matrix ``i`` occupies rows and columns ``[offset_i, offset_i + block_size[i])``, where\n",
    "    ``offset_i`` is the sum of the preceding block sizes. Stacking adjacency matrices this\n",
    "    way yields one graph made of disconnected copies, so a model can score them in one call.\n",
    "\n",
    "    Args:\n",
    "        *matrices: sparse COO tensors (coalesced here, so uncoalesced input is fine).\n",
    "        block_size: one int used for every block, or a sequence with one size per\n",
    "            matrix (normally each matrix's ``shape[0]``).\n",
    "\n",
    "    Returns:\n",
    "        A sparse COO tensor of shape ``(sum(block_size), sum(block_size))``.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if no matrices are given or ``block_size`` has the wrong length.\n",
    "    \"\"\"\n",
    "    if isinstance(block_size, int):\n",
    "        block_size = [block_size] * len(matrices)\n",
    "    block_size = [int(size) for size in block_size]\n",
    "    if not matrices:\n",
    "        raise ValueError(\"cat_block_sparse needs at least one matrix\")\n",
    "    if len(block_size) != len(matrices):\n",
    "        raise ValueError(f\"got {len(block_size)} block sizes for {len(matrices)} matrices\")\n",
    "\n",
    "    # .indices() / .values() raise on uncoalesced sparse tensors, so coalesce first\n",
    "    blocks = [matrix.coalesce() for matrix in matrices]\n",
    "\n",
    "    # offset of each block = running sum of the sizes before it\n",
    "    offsets = []\n",
    "    total = 0\n",
    "    for size in block_size:\n",
    "        offsets.append(total)\n",
    "        total += size\n",
    "\n",
    "    # the same scalar shift on row and column indices keeps each block on the diagonal\n",
    "    cat_indices = torch.cat([block.indices() + offset for block, offset in zip(blocks, offsets)], dim=1)\n",
    "    cat_values = torch.cat([block.values() for block in blocks], dim=0)\n",
    "    return torch.sparse_coo_tensor(\n",
    "        indices=cat_indices, values=cat_values, size=(total, total))\n",
    "\n",
    "adj_block_sizes = [matrix.shape[0] for matrix in (adj_instance1, adj_instance2, adj_instance3)]\n",
    "adj_group = cat_block_sparse(*adj_instances, block_size=adj_block_sizes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 62,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([8985])"
      ]
     },
     "execution_count": 62,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# one forward pass over the stacked graphs; argmax(dim=-1) keeps the index of the\n",
    "# largest output per row (presumably the predicted class of each stacked node)\n",
    "self_.model(attr_group, adj_group).argmax(dim=-1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[tensor(indices=tensor([[   0,    0,    0,  ..., 2993, 2993, 2994],\n",
       "                        [1636, 1638, 2357,  ...,  745, 1865, 1452]]),\n",
       "        values=tensor([1., 1., 1.,  ..., 1., 1., 1.]),\n",
       "        device='cuda:0', size=(2995, 2995), nnz=16516, layout=torch.sparse_coo),\n",
       " tensor(indices=tensor([[   0,    0,    0,  ..., 2993, 2993, 2994],\n",
       "                        [1636, 1638, 2357,  ...,  745, 1865, 1452]]),\n",
       "        values=tensor([1., 1., 1.,  ..., 1., 1., 1.]),\n",
       "        device='cuda:0', size=(2995, 2995), nnz=16516, layout=torch.sparse_coo),\n",
       " tensor(indices=tensor([[   0,    0,    0,  ..., 2993, 2993, 2994],\n",
       "                        [1636, 1638, 2357,  ...,  745, 1865, 1452]]),\n",
       "        values=tensor([1., 1., 1.,  ..., 1., 1., 1.]),\n",
       "        device='cuda:0', size=(2995, 2995), nnz=16516, layout=torch.sparse_coo)]"
      ]
     },
     "execution_count": 35,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# inspect the sampled perturbed adjacency matrices\n",
    "adj_instances"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# TODO: unfinished statement (`adj_group = `) kept as a comment so Run All does not stop on a\n",
    "# SyntaxError; adj_group is already built by the cat_block_sparse call above.\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
