{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/opt/conda/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
      "  from .autonotebook import tqdm as notebook_tqdm\n",
      "WARNING:root:Cuda kernels could not loaded -> no CUDA support!\n",
      "2024-08-21 10:22:39,718\tINFO util.py:154 -- Missing packages: ['ipywidgets']. Run `pip install -U ipywidgets`, then restart the notebook server for rich notebook output.\n",
      "WARNING:evotorch:The logger is already configured. The default configuration will not be applied. Call `set_default_logger_config` with `override=True` to override the current configuration.\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import yaml\n",
    "from ml_collections import ConfigDict\n",
    "from tqdm import tqdm\n",
    "\n",
    "import torch\n",
    "import torch_geometric\n",
    "\n",
    "from gnn_setup.utils.configs_manager import refine_dataset_configs, refine_model_configs\n",
    "from gnn_setup.utils.storage import load_split_files\n",
    "from gnn_setup.setups.data import make_dataset_split, check_split_valid, load_dataset_split, load_dataset\n",
    "from gnn_setup.setups.models import load_model_class\n",
    "from gnn_setup.setups.models import make_trained_model, load_trained_model\n",
    "from gnn_setup.utils.tensors import set_seed"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "from gnn_setup.setups.data import load_attr_adj, splited_datasets\n",
    "from gnn_setup.gnns.helpers.train import train, train_inductive\n",
    "\n",
    "from gnn_setup.utils.storage import TensorHash, model_storage_label\n",
    "from gnn_setup.utils.metrics import accuracy_from_data as accuracy\n",
    "\n",
    "from gnn_setup.utils.robust_training_utils import count_edges_for_idx\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 56,
   "metadata": {},
   "outputs": [],
   "source": [
    "# General configs: dataset name, model name, etc.\n",
    "dataset_name = \"cora_ml\"\n",
    "model_name = \"MLP\"\n",
    "n_runs = 5 # TODO: previously it was n_splits.\n",
    "inductive = True \n",
    "\n",
    "# Configs for splits\n",
    "training_nodes = None # number of training nodes (if integer it should be per-class)\n",
    "validation_nodes = None \n",
    "training_split_type = None # it is either \"stratified\" or \"non-stratified\"\n",
    "validation_split_type = None\n",
    "test_nodes = None\n",
    "test_split_type = None\n",
    "\n",
    "model_configs = None # it is a dictionary of model parameters\n",
    "retrain_models = False\n",
    "save_models = True\n",
    "wandb_flag = True\n",
    "wandb_project = \"EAV-vanila-train\"\n",
    "wandb_entity = \"WANDB-Research\"\n",
    "\n",
    "seed= 10\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 57,
   "metadata": {},
   "outputs": [],
   "source": [
    "set_seed(seed)\n",
    "\n",
    "\n",
    "# Loading general configs (like dataset_root, etc.) and initial parameters\n",
    "general_config = yaml.safe_load(open(\"../conf/general-config.yaml\"))\n",
    "default_dataset_configs = yaml.safe_load(open(\"../conf/data-configs.yaml\")).get(\"configs\").get(\"default\")\n",
    "default_model_configs = yaml.safe_load(open(\"../conf/model-configs.yaml\")).get(\"configs\")\n",
    "\n",
    "# extracting directory paths\n",
    "dataset_root = general_config.get(\"dataset_root\", \"data/\")\n",
    "splits_root = general_config.get(\"splits_root\", \"splits/\")\n",
    "models_root = general_config.get(\"models_root\", \"models/\")\n",
    "\n",
    "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
    "\n",
    "\n",
    "\n",
    "# region Loading dataset and splits\n",
    "refined_dataset_configs = refine_dataset_configs(\n",
    "    dataset_defaults=default_dataset_configs, \n",
    "    training_nodes=training_nodes, validation_nodes=validation_nodes, test_nodes=test_nodes, \n",
    "    training_split_type=training_split_type, validation_split_type=validation_split_type, test_split_type=test_split_type)\n",
    "\n",
    "training_nodes = refined_dataset_configs[\"training_nodes\"]\n",
    "validation_nodes = refined_dataset_configs[\"validation_nodes\"]\n",
    "test_nodes = refined_dataset_configs[\"test_nodes\"]\n",
    "training_split_type = refined_dataset_configs[\"training_split_type\"]\n",
    "validation_split_type = refined_dataset_configs[\"validation_split_type\"]\n",
    "test_split_type = refined_dataset_configs[\"test_split_type\"]\n",
    "\n",
    "dataset_split_files = load_split_files(\n",
    "    splits_root=splits_root, make_if_not_exists=True, dataset_name=dataset_name,\n",
    "    training_nodes=training_nodes, validation_nodes=validation_nodes, test_nodes=test_nodes,\n",
    "    training_split_type=training_split_type, validation_split_type=validation_split_type, \n",
    "    test_split_type=test_split_type,) \n",
    "\n",
    "dataset_splits = [load_dataset_split(\n",
    "    dataset_name=dataset_name, split_name=split_name, dataset_root=dataset_root, splits_root=splits_root, device=device\n",
    ") for split_name in dataset_split_files]\n",
    "\n",
    "if len(dataset_splits) < n_runs:\n",
    "    raise ValueError(\"No. Runs = {} is greater than available splits = {}\".format(n_runs, len(dataset_splits)))\n",
    "dataset_splits = dataset_splits[:n_runs]\n",
    "\n",
    "dataset, dataset_info = load_dataset(dataset_name, dataset_root)\n",
    "# endregion\n",
    "\n",
    "model_configs = refine_model_configs(model_name=\"GCN\", model_defaults=default_model_configs, \n",
    "                                    model_configs=model_configs,  dataset_info=dataset_info)\n",
    "\n",
    "accs = []\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 58,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'n_filters': 64,\n",
       " 'lr': 0.01,\n",
       " 'weight_decay': 0.001,\n",
       " 'patience': 200,\n",
       " 'max_epochs': 3000,\n",
       " 'n_features': 2879,\n",
       " 'n_classes': 7}"
      ]
     },
     "execution_count": 58,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "model_configs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 59,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MLP(torch.nn.Module):\n",
    "    def __init__(self, n_features, n_filters, n_classes, **kwargs):\n",
    "        super(MLP, self).__init__()\n",
    "        self.n_features = n_features\n",
    "        self.n_filters = n_filters\n",
    "        self.n_classes = n_classes\n",
    "        self.dropout_p = 0.0\n",
    "        \n",
    "\n",
    "        self.lin1 = torch.nn.Linear(self.n_features, self.n_filters)\n",
    "        self.lin2 = torch.nn.Linear(self.n_filters, self.n_classes)\n",
    "\n",
    "    def forward(self, x, adj):\n",
    "        x = torch.relu(self.lin1(x))\n",
    "        x = torch.dropout(x, p=self.dropout_p, train=self.training)\n",
    "        x = self.lin2(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 60,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training...:   7%|▋         | 209/3000 [00:00<00:03, 737.63it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model's accuracy: 0.597864768683274 -- stored under the name: MLP-0d0663459b-ind-cora_ml-cora_ml-403876751f.pt-\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training...:   7%|▋         | 218/3000 [00:00<00:03, 746.09it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model's accuracy: 0.6263345195729537 -- stored under the name: MLP-0d0663459b-ind-cora_ml-cora_ml-5ca55fdeb8.pt-\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training...:   7%|▋         | 215/3000 [00:00<00:03, 746.68it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model's accuracy: 0.5907473309608541 -- stored under the name: MLP-0d0663459b-ind-cora_ml-cora_ml-2a459294c7.pt-\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training...:   7%|▋         | 209/3000 [00:00<00:03, 753.01it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model's accuracy: 0.5266903914590747 -- stored under the name: MLP-0d0663459b-ind-cora_ml-cora_ml-bac36c0b22.pt-\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Training...:   7%|▋         | 208/3000 [00:00<00:03, 751.31it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Model's accuracy: 0.6690391459074733 -- stored under the name: MLP-0d0663459b-ind-cora_ml-cora_ml-7bac1a28fd.pt-\n"
     ]
    }
   ],
   "source": [
    "for split in dataset_splits:\n",
    "# split = dataset_splits[0]\n",
    "\n",
    "\n",
    "    # region loading the split\n",
    "    training_idx = split[\"training_idx\"].to(device)\n",
    "    validation_idx = split[\"validation_idx\"].to(device)\n",
    "    test_idx = split[\"test_idx\"].to(device)\n",
    "    unlabeled_idx = split[\"unlabeled_idx\"].to(device)\n",
    "    dataset_info = split[\"dataset_info\"]\n",
    "    split_name = split[\"split_name\"]\n",
    "    split_config = split[\"config\"]\n",
    "    # endregion\n",
    "\n",
    "    model = MLP(**model_configs).to(device)\n",
    "\n",
    "    training_dataset, validation_dataset, test_dataset = splited_datasets(\n",
    "            dataset, dataset_info=dataset_info, \n",
    "            training_idx=training_idx, validation_idx=validation_idx, test_idx=test_idx, unlabeled_idx=unlabeled_idx,\n",
    "            inductive=inductive)\n",
    "\n",
    "    training_attr, training_adj = load_attr_adj(training_dataset, training_idx, device=device)\n",
    "    validation_attr, validation_adj = load_attr_adj(validation_dataset, validation_idx, device=device)\n",
    "    test_attr, test_adj = load_attr_adj(dataset, test_idx, device=device)\n",
    "\n",
    "    if not inductive:\n",
    "        training_trace = train(\n",
    "            model=model, attr=training_attr.to(device), adj=training_adj.to(device), labels=training_dataset.y.to(device),\n",
    "            idx_train=training_idx, idx_val=validation_idx, display_step=100,\n",
    "            lr=model_configs.get(\"lr\", None), \n",
    "            weight_decay=model_configs.get(\"weight_decay\", None), \n",
    "            patience=model_configs.get(\"patience\", None),\n",
    "            max_epochs=model_configs.get(\"max_epochs\", None),\n",
    "        )\n",
    "    else:\n",
    "        training_trace = train_inductive(\n",
    "            model=model, attr_training=training_attr.to(device), attr_validation=validation_attr.to(device), \n",
    "            adj_training=training_adj.to(device), adj_validation=validation_adj.to(device),\n",
    "            labels_training=training_dataset.y.to(device), labels_validation=validation_dataset.y.to(device),\n",
    "            idx_train=training_idx, idx_val=validation_idx, display_step=100,\n",
    "            lr=model_configs.get(\"lr\", None),\n",
    "            weight_decay=model_configs.get(\"weight_decay\", None),\n",
    "            patience=model_configs.get(\"patience\", None),\n",
    "            max_epochs=model_configs.get(\"max_epochs\", None),\n",
    "        )\n",
    "\n",
    "    eval_mask = torch.zeros(dataset_info.n_nodes, dtype=torch.bool)\n",
    "    eval_mask[test_idx] = True\n",
    "    if not inductive:\n",
    "        eval_mask[unlabeled_idx] = True\n",
    "    acc = accuracy(model, test_attr, test_adj, test_dataset.y.to(device), eval_mask)\n",
    "\n",
    "    model_storage_name = model_storage_label(\n",
    "        model_name=model_name, \n",
    "        model_params=model_configs, \n",
    "        dataset_info=dataset_info, \n",
    "        inductive=inductive, \n",
    "        split_name=split_name)\n",
    "    try:\n",
    "        if save_models:\n",
    "            os.makedirs(models_root, exist_ok=True)\n",
    "            torch.save(model.state_dict(), os.path.join(models_root, f\"{model_storage_name}.pt\"))\n",
    "    except RuntimeError as e:\n",
    "        print(f\"Error saving the model: {e}\")\n",
    "\n",
    "    model_instance = {\n",
    "        \"model\": model,\n",
    "        \"model_configs\": model_configs, \n",
    "        \"accuracy\": acc,\n",
    "        \"model_storage_name\": model_storage_name,\n",
    "    }\n",
    "    acc = model_instance[\"accuracy\"]\n",
    "    model_configs = model_instance[\"model_configs\"]\n",
    "    model_storage_name = model_instance[\"model_storage_name\"]\n",
    "    print(f\"Model's accuracy: {acc} -- stored under the name: {model_storage_name}\")\n",
    "    accs.append(acc)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 61,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Average accuracy: 0.6021352313167261, with standard deviation: 0.05220533162355423\n"
     ]
    }
   ],
   "source": [
    "print(f\"Average accuracy: {sum(accs)/len(accs)}, with standard deviation: {torch.std(torch.tensor(accs))}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[0.6681887366818873,\n",
       " 0.7047184170471841,\n",
       " 0.4916286149162861,\n",
       " 0.7148655504819887,\n",
       " 0.6078132927447996]"
      ]
     },
     "execution_count": 55,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "accs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
