{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1c680e41-57da-4ded-9296-b32216d35174",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data Loading\n",
    "import os.path as osp\n",
    "\n",
    "# Pytorch Module\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv,GATConv, SAGEConv, Sequential\n",
    "from torch_geometric.data import Dataset, download_url,DataLoader\n",
    "from torch_geometric.utils import to_undirected\n",
    "\n",
    "# Optuna Module\n",
    "import optuna\n",
    "from optuna.trial import TrialState\n",
    "from optuna.importance import get_param_importances\n",
    "\n",
    "# F1-Score\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# Parsing\n",
    "import argparse"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "e453c135-bdcf-4160-bc7f-64d58fb3f292",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Define the Class to load the data\n",
    "class MyOwnDataset(Dataset):\n",
    "    \"\"\"Read-only PyG dataset over graphs already saved as `data_{idx}.pt` files.\n",
    "\n",
    "    Args:\n",
    "        num_data: number of graphs; files `data_0.pt` .. `data_{num_data-1}.pt`\n",
    "            are read from `self.processed_dir` (derived by PyG from `root`).\n",
    "        root: dataset root directory, forwarded to the PyG base class.\n",
    "        transform, pre_transform: forwarded unchanged to the PyG base class.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_data, root, transform=None, pre_transform=None):\n",
    "        # Assign before calling the base constructor: processed_file_names depends\n",
    "        # on num_data, and PyG's Dataset.__init__ may access it (e.g. if a\n",
    "        # process() hook is ever added), which would otherwise raise AttributeError.\n",
    "        self.num_data = num_data\n",
    "        super(MyOwnDataset, self).__init__(root, transform, pre_transform)\n",
    "\n",
    "    @property\n",
    "    def raw_file_names(self):\n",
    "        # No raw files: the graphs are stored pre-processed.\n",
    "        return []\n",
    "\n",
    "    @property\n",
    "    def processed_file_names(self):\n",
    "        return ['data_{}.pt'.format(idx) for idx in range(self.num_data)]\n",
    "\n",
    "    def len(self):\n",
    "        return len(self.processed_file_names)\n",
    "\n",
    "    def get(self, idx):\n",
    "        # NOTE(review): torch.load unpickles the file -- only load trusted data.\n",
    "        data = torch.load(osp.join(self.processed_dir, 'data_{}.pt'.format(idx)))\n",
    "        return data\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "5aad6d39-5b15-436b-b0ee-cdad624f568b",
   "metadata": {},
   "outputs": [],
   "source": [
    "# F1 loss\n",
    "class F1_Loss(nn.Module):\n",
    "    '''Differentiable macro-F1 loss (1 - mean per-class soft F1). Works on GPU tensors.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    epsilon : float\n",
    "        Guards the divisions and bounds each per-class F1 to [epsilon, 1 - epsilon].\n",
    "\n",
    "    forward(y_pred, y_true)\n",
    "        y_pred : (N, C) tensor of unnormalised scores; softmax is applied inside.\n",
    "        y_true : (N,) integer tensor of class indices in [0, C) (int64, as F.one_hot requires).\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    torch.Tensor\n",
    "        0-dim loss, epsilon <= val <= 1 - epsilon.\n",
    "    '''\n",
    "\n",
    "    def __init__(self, epsilon=1e-6):\n",
    "        super().__init__()\n",
    "        self.epsilon = epsilon\n",
    "\n",
    "    def forward(self, y_pred, y_true):\n",
    "        assert y_pred.ndim == 2\n",
    "        assert y_true.ndim == 1\n",
    "        # Number of classes comes from the logits instead of a hard-coded 3.\n",
    "        num_classes = y_pred.size(1)\n",
    "        y_true = F.one_hot(y_true, num_classes).to(torch.float32)\n",
    "        y_pred = F.softmax(y_pred, dim=1)\n",
    "\n",
    "        # Soft per-class confusion counts summed over the N rows\n",
    "        # (true negatives are not needed for F1).\n",
    "        tp = (y_true * y_pred).sum(dim=0).to(torch.float32)\n",
    "        fp = ((1 - y_true) * y_pred).sum(dim=0).to(torch.float32)\n",
    "        fn = (y_true * (1 - y_pred)).sum(dim=0).to(torch.float32)\n",
    "\n",
    "        precision = tp / (tp + fp + self.epsilon)\n",
    "        recall = tp / (tp + fn + self.epsilon)\n",
    "\n",
    "        f1 = 2 * (precision * recall) / (precision + recall + self.epsilon)\n",
    "        f1 = f1.clamp(min=self.epsilon, max=1 - self.epsilon)\n",
    "        return 1 - f1.mean()\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "c0052a0a-9969-47ff-9b8f-b66aa55d2609",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the Sequential Network\n",
    "def define_model_GATConv(trial, CLASSES, NODE_FEATURES):\n",
    "    \"\"\"Sample a graph-attention (GATConv) node classifier from an Optuna trial.\n",
    "\n",
    "    Architecture: GATConv -> ReLU, then `n_layers - 1` blocks of\n",
    "    GATConv -> ReLU -> Dropout, then up to three hidden Linear -> ReLU layers,\n",
    "    then a final Linear producing CLASSES logits per node.\n",
    "\n",
    "    Args:\n",
    "        trial: Optuna trial the hyper-parameters are drawn from.\n",
    "        CLASSES: number of output logits per node.\n",
    "        NODE_FEATURES: size of the input node-feature vectors.\n",
    "\n",
    "    Returns:\n",
    "        A torch_geometric.nn.Sequential called as model(x, edge_index).\n",
    "    \"\"\"\n",
    "    depth = trial.suggest_int(\"n_layers\", 1, 5)\n",
    "\n",
    "    first_units = trial.suggest_int(\"n_units_l0\", 10, 500)\n",
    "    first_heads = trial.suggest_int(\"head_l0\", 1, 6)\n",
    "    modules = [\n",
    "        (GATConv(NODE_FEATURES, first_units, first_heads), 'x, edge_index -> x'),\n",
    "        # NOTE(review): edge_index is made undirected only after the first conv,\n",
    "        # so that layer sees the graph as stored -- confirm this is intended.\n",
    "        (lambda x, edge_index: (F.relu(x), to_undirected(edge_index)), 'x, edge_index -> x, edge_index'),\n",
    "    ]\n",
    "    # GATConv concatenates its heads (PyG default), so the next input width is units * heads.\n",
    "    width = first_units * first_heads\n",
    "\n",
    "    for idx in range(1, depth):\n",
    "        units = trial.suggest_int(f\"n_units_l{idx}\", 10, 500)\n",
    "        n_heads = trial.suggest_int(f\"head_l{idx}\", 1, 6)\n",
    "        modules.append((GATConv(width, units, n_heads), 'x, edge_index -> x'))\n",
    "        modules.append(nn.ReLU(inplace=True))\n",
    "        drop = trial.suggest_float(f\"dropout_l{idx}\", 0.01, 0.2)\n",
    "        modules.append(nn.Dropout(drop))\n",
    "        width = units * n_heads\n",
    "\n",
    "    # NOTE(review): the loop starts at 1, so n_layers_linear = 0 and 1 both add\n",
    "    # no hidden Linear layer; at most three are added.\n",
    "    n_linear = trial.suggest_int(\"n_layers_linear\", 0, 4)\n",
    "    for idx in range(1, n_linear):\n",
    "        units = trial.suggest_int(f\"n_units_lin_l{idx}\", 4, 250)\n",
    "        modules += [(nn.Linear(width, units), 'x -> x'), nn.ReLU(inplace=True)]\n",
    "        width = units\n",
    "\n",
    "    modules.append((nn.Linear(width, CLASSES), 'x -> x'))\n",
    "    return Sequential('x, edge_index', modules)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "8fb98edb-f715-41fb-a517-9753c297f968",
   "metadata": {},
   "outputs": [],
   "source": [
    "def define_model_SAGEConv(trial, CLASSES, NODE_FEATURES):\n",
    "    \"\"\"Sample a GraphSAGE (SAGEConv) node classifier from an Optuna trial.\n",
    "\n",
    "    Architecture: SAGEConv -> ReLU, then `n_layers - 1` blocks of\n",
    "    SAGEConv -> ReLU -> Dropout, then up to three hidden Linear -> ReLU layers,\n",
    "    then a final Linear producing CLASSES logits per node.\n",
    "\n",
    "    Args:\n",
    "        trial: Optuna trial the hyper-parameters are drawn from.\n",
    "        CLASSES: number of output logits per node.\n",
    "        NODE_FEATURES: size of the input node-feature vectors.\n",
    "\n",
    "    Returns:\n",
    "        A torch_geometric.nn.Sequential called as model(x, edge_index).\n",
    "    \"\"\"\n",
    "    depth = trial.suggest_int(\"n_layers\", 1, 6)\n",
    "\n",
    "    first_units = trial.suggest_int(\"n_units_l0\", 4, 500)\n",
    "    modules = [\n",
    "        (SAGEConv(NODE_FEATURES, first_units), 'x, edge_index -> x'),\n",
    "        # NOTE(review): edge_index is made undirected only after the first conv,\n",
    "        # so that layer sees the graph as stored -- confirm this is intended.\n",
    "        (lambda x, edge_index: (F.relu(x), to_undirected(edge_index)), 'x, edge_index -> x, edge_index'),\n",
    "    ]\n",
    "    width = first_units\n",
    "\n",
    "    for idx in range(1, depth):\n",
    "        units = trial.suggest_int(f\"n_units_l{idx}\", 4, 500)\n",
    "        modules.append((SAGEConv(width, units), 'x, edge_index -> x'))\n",
    "        modules.append(nn.ReLU(inplace=True))\n",
    "        drop = trial.suggest_float(f\"dropout_l{idx}\", 0.01, 0.2)\n",
    "        modules.append(nn.Dropout(drop))\n",
    "        width = units\n",
    "\n",
    "    # NOTE(review): the loop starts at 1, so n_layers_linear = 0 and 1 both add\n",
    "    # no hidden Linear layer; at most three are added.\n",
    "    n_linear = trial.suggest_int(\"n_layers_linear\", 0, 4)\n",
    "    for idx in range(1, n_linear):\n",
    "        units = trial.suggest_int(f\"n_units_lin_l{idx}\", 4, 250)\n",
    "        modules += [(nn.Linear(width, units), 'x -> x'), nn.ReLU(inplace=True)]\n",
    "        width = units\n",
    "\n",
    "    modules.append((nn.Linear(width, CLASSES), 'x -> x'))\n",
    "    return Sequential('x, edge_index', modules)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "546d567c-ba1b-4034-b8f0-03b5933a7548",
   "metadata": {},
   "outputs": [],
   "source": [
    "def define_model_GCNConv(trial, CLASSES, NODE_FEATURES):\n",
    "    \"\"\"Sample a GCN (GCNConv) node classifier from an Optuna trial.\n",
    "\n",
    "    Architecture: GCNConv -> ReLU, then `n_layers - 1` blocks of\n",
    "    GCNConv -> ReLU -> Dropout, then up to three hidden Linear -> ReLU layers,\n",
    "    then a final Linear producing CLASSES logits per node. Conv widths after\n",
    "    the first are sampled from [4, 250] and dropout from [0.001, 0.2].\n",
    "\n",
    "    Args:\n",
    "        trial: Optuna trial the hyper-parameters are drawn from.\n",
    "        CLASSES: number of output logits per node.\n",
    "        NODE_FEATURES: size of the input node-feature vectors.\n",
    "\n",
    "    Returns:\n",
    "        A torch_geometric.nn.Sequential called as model(x, edge_index).\n",
    "    \"\"\"\n",
    "    depth = trial.suggest_int(\"n_layers\", 1, 6)\n",
    "\n",
    "    first_units = trial.suggest_int(\"n_units_l0\", 4, 500)\n",
    "    modules = [\n",
    "        (GCNConv(NODE_FEATURES, first_units), 'x, edge_index -> x'),\n",
    "        # NOTE(review): edge_index is made undirected only after the first conv,\n",
    "        # so that layer sees the graph as stored -- confirm this is intended.\n",
    "        (lambda x, edge_index: (F.relu(x), to_undirected(edge_index)), 'x, edge_index -> x, edge_index'),\n",
    "    ]\n",
    "    width = first_units\n",
    "\n",
    "    for idx in range(1, depth):\n",
    "        units = trial.suggest_int(f\"n_units_l{idx}\", 4, 250)\n",
    "        modules.append((GCNConv(width, units), 'x, edge_index -> x'))\n",
    "        modules.append(nn.ReLU(inplace=True))\n",
    "        drop = trial.suggest_float(f\"dropout_l{idx}\", 0.001, 0.2)\n",
    "        modules.append(nn.Dropout(drop))\n",
    "        width = units\n",
    "\n",
    "    # NOTE(review): the loop starts at 1, so n_layers_linear = 0 and 1 both add\n",
    "    # no hidden Linear layer; at most three are added.\n",
    "    n_linear = trial.suggest_int(\"n_layers_linear\", 0, 4)\n",
    "    for idx in range(1, n_linear):\n",
    "        units = trial.suggest_int(f\"n_units_lin_l{idx}\", 4, 250)\n",
    "        modules += [(nn.Linear(width, units), 'x -> x'), nn.ReLU(inplace=True)]\n",
    "        width = units\n",
    "\n",
    "    modules.append((nn.Linear(width, CLASSES), 'x -> x'))\n",
    "    return Sequential('x, edge_index', modules)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "6f051c31-8eb9-4def-8bd4-ef5f3f9e1c3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the Optuna objective function\n",
    "def _build_model(trial, CLASSES, NODE_FEATURES):\n",
    "    \"\"\"Let `trial` pick the GNN family and build that model.\"\"\"\n",
    "    # Only GAT is currently searched; add 'GCN' / 'Sage' to the choices to enable them.\n",
    "    classifier_name = trial.suggest_categorical('classifier', ['GAT'])\n",
    "    if classifier_name == 'GCN':\n",
    "        return define_model_GCNConv(trial, CLASSES, NODE_FEATURES)\n",
    "    if classifier_name == 'Sage':\n",
    "        return define_model_SAGEConv(trial, CLASSES, NODE_FEATURES)\n",
    "    return define_model_GATConv(trial, CLASSES, NODE_FEATURES)\n",
    "\n",
    "\n",
    "def _train_one_epoch(model, loader, optimizer, criterion, max_grad_norm, DEVICE):\n",
    "    \"\"\"Run one optimisation pass over `loader`, clipping the gradient norm.\"\"\"\n",
    "    model.train()\n",
    "    for ele in loader:\n",
    "        optimizer.zero_grad()\n",
    "        ele = ele.to(DEVICE)\n",
    "        ele[\"x\"] = ele[\"x\"].float()\n",
    "        ele[\"y\"] = ele[\"y\"].long()\n",
    "        out = model(ele.x, ele.edge_index)\n",
    "        # NOTE(review): the loss uses every node, while validation only scores nodes\n",
    "        # with indice_nt_tensor > 0 -- confirm this asymmetry is intended.\n",
    "        loss_calc = criterion(out, ele[\"y\"])\n",
    "        loss_calc.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)\n",
    "        optimizer.step()\n",
    "\n",
    "\n",
    "def _evaluate_macro_f1(model, loader, DEVICE, CLASSES):\n",
    "    \"\"\"Macro F1 over the nodes whose `indice_nt_tensor` entry is > 0.\"\"\"\n",
    "    model.eval()\n",
    "    y_true, y_pred = [], []\n",
    "    with torch.no_grad():\n",
    "        for ele in loader:\n",
    "            ele = ele.to(DEVICE)\n",
    "            ele[\"x\"] = ele[\"x\"].float()\n",
    "            ele[\"y\"] = ele[\"y\"].long()\n",
    "            out = model(ele.x, ele.edge_index)\n",
    "            # Boolean node mask rather than sentinel values (-1e9 / -10000), which\n",
    "            # could misalign rows if a real logit or label equalled the sentinel.\n",
    "            keep = ele[\"indice_nt_tensor\"] > 0\n",
    "            y_true.extend(ele[\"y\"][keep].tolist())\n",
    "            y_pred.extend(out[keep].argmax(1).tolist())\n",
    "    # sklearn's signature is f1_score(y_true, y_pred); order matters for per-class P/R.\n",
    "    return f1_score(y_true, y_pred, labels=list(range(CLASSES)), average=\"macro\")\n",
    "\n",
    "\n",
    "def objective(trial, DEVICE, EPOCHS, CLASSES, NODE_FEATURES, train_loader, dev_loader, test_loader):\n",
    "    \"\"\"Optuna objective: train a sampled GNN and return its dev-set macro F1.\n",
    "\n",
    "    The score of the last epoch is returned; every epoch's score is also\n",
    "    reported through trial.report. `test_loader` is accepted but not used.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: if EPOCHS < 1 (no validation score would be produced).\n",
    "    \"\"\"\n",
    "    if EPOCHS < 1:\n",
    "        raise ValueError(\"EPOCHS must be >= 1, got {}\".format(EPOCHS))\n",
    "    model = _build_model(trial, CLASSES, NODE_FEATURES).to(DEVICE)\n",
    "\n",
    "    optimizer_name = trial.suggest_categorical(\"optimizer\", [\"AdamW\"])\n",
    "    lr = trial.suggest_float(\"lr\", 1e-5, 1e-3, log=True)\n",
    "    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)\n",
    "    # Multiply the learning rate by 0.99 after every epoch.\n",
    "    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.99)\n",
    "    max_grad_norm = trial.suggest_float(\"max_grad_norm\", 1e-1, 1e+1)\n",
    "\n",
    "    # Stateless, so build it once rather than once per batch.\n",
    "    # Alternative: criterion = F1_Loss().to(DEVICE)\n",
    "    criterion = nn.CrossEntropyLoss()\n",
    "\n",
    "    for epoch in range(EPOCHS):\n",
    "        _train_one_epoch(model, train_loader, optimizer, criterion, max_grad_norm, DEVICE)\n",
    "        f1_score_token = _evaluate_macro_f1(model, dev_loader, DEVICE, CLASSES)\n",
    "        trial.report(f1_score_token, epoch)\n",
    "        # Pruning is disabled; to enable it:\n",
    "        # if trial.should_prune():\n",
    "        #     raise optuna.exceptions.TrialPruned()\n",
    "        scheduler.step()\n",
    "    return f1_score_token\n",
    "\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e44ec0f6-4462-4b76-85bf-5d91ceef3e35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the main function to call by script\n",
    "def main():\n",
    "    \"\"\"Run the Optuna hyper-parameter search and print a summary of the study.\n",
    "\n",
    "    Loads the Train/Dev/Test splits under `data_dir`, maximises the dev-set\n",
    "    score returned by `objective` in a SQLite-backed study (resumed if it\n",
    "    already exists), then prints trial counts, the best trial and the\n",
    "    parameter importances.\n",
    "    \"\"\"\n",
    "    batch_size = 150\n",
    "    num_trials = 200\n",
    "    #data_dir = '../../data/aurc/bert/large_depth_all_connected'\n",
    "    data_dir = '../../data/aurc/bert/large_depth_IN_connected'\n",
    "    EPOCHS = 15\n",
    "    NODE_FEATURES = 1024\n",
    "    DEVICE = torch.device('cuda:2' if torch.cuda.is_available() else 'cpu')\n",
    "    CLASSES = 3\n",
    "    timeout = 3600*8  # seconds, i.e. 8 hours\n",
    "    num_data_train = 3960\n",
    "    num_data_dev = 790\n",
    "    num_data_test = 1959\n",
    "\n",
    "    train_dataset = MyOwnDataset(num_data=num_data_train, root=data_dir+\"/Train\")\n",
    "    test_dataset = MyOwnDataset(num_data=num_data_test, root=data_dir+\"/Test\")\n",
    "    dev_dataset = MyOwnDataset(num_data=num_data_dev, root=data_dir+\"/Dev\")\n",
    "\n",
    "    # Reshuffle the training graphs every epoch; evaluation order does not matter.\n",
    "    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)\n",
    "    dev_loader = DataLoader(dev_dataset, batch_size=batch_size)\n",
    "    test_loader = DataLoader(test_dataset, batch_size=batch_size)\n",
    "\n",
    "    ## Initialize the optuna research (resumes the stored study if present)\n",
    "    study = optuna.create_study(\n",
    "        direction=\"maximize\",\n",
    "        study_name = \"Trained_BERT_GNN_IN_connected\",\n",
    "        storage=\"sqlite:///../optuna_db/Trained_BERT_large_GNN_IN_connected.db\",\n",
    "        load_if_exists=True)\n",
    "\n",
    "    study.optimize(lambda trial: objective(trial, DEVICE, EPOCHS, CLASSES, NODE_FEATURES, train_loader, dev_loader, test_loader), n_trials=num_trials, timeout=timeout)\n",
    "\n",
    "    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])\n",
    "    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])\n",
    "\n",
    "    print(\"Study statistics: \")\n",
    "    print(\"  Number of finished trials: \", len(study.trials))\n",
    "    print(\"  Number of pruned trials: \", len(pruned_trials))\n",
    "    print(\"  Number of complete trials: \", len(complete_trials))\n",
    "\n",
    "    # study.best_trial raises ValueError when no trial has completed\n",
    "    # (e.g. if every trial failed).\n",
    "    if complete_trials:\n",
    "        print(\"Best trial:\")\n",
    "        trial = study.best_trial\n",
    "\n",
    "        print(\"  Value: \", trial.value)\n",
    "\n",
    "        print(\"  Params: \")\n",
    "        for key, value in trial.params.items():\n",
    "            print(\"    {}: {}\".format(key, value))\n",
    "    else:\n",
    "        print(\"Best trial: none (no completed trials)\")\n",
    "\n",
    "    print(\"Name of the study : \" + study.study_name)\n",
    "\n",
    "    print(\"  Params Importance: \")\n",
    "    try:\n",
    "        dict_params = get_param_importances(study)\n",
    "    except ValueError as err:\n",
    "        # Optuna raises ValueError when there are too few completed trials to evaluate.\n",
    "        print(\"    unavailable: {}\".format(err))\n",
    "    else:\n",
    "        for key, value in dict_params.items():\n",
    "            print(\"    {}: {}\".format(key, value))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b93139a9-6688-4590-9429-dedc2af9be93",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch the (long-running) search; in a notebook kernel __name__ is \"__main__\".\n",
    "if __name__ == \"__main__\":\n",
    "    main()\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "concept_env",
   "language": "python",
   "name": "concept_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
