{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "1c680e41-57da-4ded-9296-b32216d35174",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Data Loading\n",
    "import os.path as osp\n",
    "\n",
    "# Pytorch Module\n",
    "import torch\n",
    "import torch.optim as optim\n",
    "from torch import nn\n",
    "import torch.nn.functional as F\n",
    "from torch_geometric.nn import GCNConv,GATConv, SAGEConv, Sequential\n",
    "from torch_geometric.data import Dataset, download_url,DataLoader\n",
    "from torch_geometric.utils import to_undirected\n",
    "\n",
    "# Optuna Module\n",
    "import optuna\n",
    "from optuna.trial import TrialState\n",
    "from optuna.importance import get_param_importances\n",
    "\n",
    "# F1-Score\n",
    "from sklearn.metrics import f1_score\n",
    "\n",
    "# Parsing\n",
    "import argparse\n",
    "\n",
    "import sys\n",
    "sys.path.insert(0, '../../utils/')\n",
    "from utils_GNN_CRF import (\n",
    "    compute_tab_logits_pred_sentence_level,\n",
    "    extract_sentence_level_label,\n",
    "    F1_Loss,\n",
    "    find_max_size_batch_without_nt,\n",
    "    remove_internal_nodes,\n",
    "    remove_internal_nodes_without_labels,\n",
    "    reshape_label_to_batch_padded,\n",
    "    reshape_mask_to_batch_padded,\n",
    "    reshape_data_to_batch_padded,\n",
    "    find_the_most_frequent_element\n",
    ")\n",
    "\n",
    "from torchcrf import CRF"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1629c9b8-5973-40c6-abf2-69f73073e1f7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "f1db22fe-a623-4874-95c0-df6b4e1b7779",
   "metadata": {},
   "outputs": [],
   "source": [
    "## Define the Class to load the data\n",
    "class MyOwnDataset(Dataset):\n",
    "    \"\"\"Graphs pre-processed into files ``data_{idx}.pt`` inside ``processed_dir``.\n",
    "\n",
    "    Nothing is downloaded or processed here: ``raw_file_names`` is empty and the\n",
    "    ``data_{idx}.pt`` files must already exist in the base class's ``processed_dir``.\n",
    "\n",
    "    Args:\n",
    "        num_data: number of graphs, i.e. files ``data_0.pt`` ... ``data_{num_data - 1}.pt``.\n",
    "        root: dataset root directory, forwarded to the base ``Dataset``.\n",
    "        transform: optional transform, forwarded to the base ``Dataset``.\n",
    "        pre_transform: optional pre-transform, forwarded to the base ``Dataset``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_data, root, transform=None, pre_transform=None):\n",
    "        # Set before calling the base constructor: ``processed_file_names`` needs\n",
    "        # ``num_data``, and the base class may query it while initialising (e.g. if a\n",
    "        # ``process`` method is ever added), which would fail if it were not set yet.\n",
    "        self.num_data = num_data\n",
    "        super(MyOwnDataset, self).__init__(root, transform, pre_transform)\n",
    "\n",
    "    @property\n",
    "    def raw_file_names(self):\n",
    "        # No raw files: the data is already processed.\n",
    "        return []\n",
    "\n",
    "    @property\n",
    "    def processed_file_names(self):\n",
    "        return [self._file_name(idx) for idx in range(self.num_data)]\n",
    "\n",
    "    @staticmethod\n",
    "    def _file_name(idx):\n",
    "        \"\"\"Name of the file holding graph ``idx``.\"\"\"\n",
    "        return 'data_{}.pt'.format(idx)\n",
    "\n",
    "    def len(self):\n",
    "        \"\"\"Number of graphs in the dataset.\"\"\"\n",
    "        return len(self.processed_file_names)\n",
    "\n",
    "    def get(self, idx):\n",
    "        \"\"\"Load and return graph ``idx``.\"\"\"\n",
    "        # torch.load unpickles the file: only use it on trusted, locally produced data.\n",
    "        return torch.load(osp.join(self.processed_dir, self._file_name(idx)))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "fd89b112-d609-4a5d-a584-be1aec7d9116",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the Sequential Network\n",
    "def define_model_GATConv(trial, CLASSES,NODE_FEATURES):\n",
    "    \"\"\"Build a GAT node classifier whose architecture is sampled by Optuna.\n",
    "\n",
    "    Layout: ``n_layers`` GATConv layers (the first followed by ReLU and by making\n",
    "    ``edge_index`` undirected for the following layers, the others by ReLU + Dropout),\n",
    "    then ``n_layers_linear`` hidden Linear + ReLU layers, then a Linear layer with\n",
    "    ``CLASSES`` outputs.\n",
    "\n",
    "    Args:\n",
    "        trial: Optuna trial used to sample the hyper-parameters.\n",
    "        CLASSES: number of outputs of the last layer.\n",
    "        NODE_FEATURES: dimension of the input node features.\n",
    "\n",
    "    Returns:\n",
    "        A ``torch_geometric.nn.Sequential`` model called as ``model(x, edge_index)``.\n",
    "    \"\"\"\n",
    "    n_layers = trial.suggest_int(\"n_layers\", 1, 5)\n",
    "    layers = []\n",
    "    in_features = NODE_FEATURES\n",
    "\n",
    "    # First attention layer. Its output is taken to have out_features * heads\n",
    "    # columns (this relies on GATConv concatenating the heads).\n",
    "    out_features = trial.suggest_int(\"n_units_l{}\".format(0), 10, 100)\n",
    "    heads = trial.suggest_int(\"head_l{}\".format(0), 1, 3)\n",
    "    layers.append((GATConv(in_features, out_features, heads), 'x, edge_index -> x'))\n",
    "    in_features = out_features * heads\n",
    "    # ReLU, and make the graph undirected for all the following graph layers.\n",
    "    layers.append((lambda x, edge_index: (F.relu(x), to_undirected(edge_index)),\n",
    "                   'x, edge_index -> x, edge_index'))\n",
    "\n",
    "    # Remaining attention layers, each followed by ReLU and Dropout.\n",
    "    for i in range(1, n_layers):\n",
    "        out_features = trial.suggest_int(\"n_units_l{}\".format(i), 10, 100)\n",
    "        heads = trial.suggest_int(\"head_l{}\".format(i), 1, 3)\n",
    "        layers.append((GATConv(in_features, out_features, heads), 'x, edge_index -> x'))\n",
    "        layers.append(nn.ReLU(inplace=True))\n",
    "        p = trial.suggest_float(\"dropout_l{}\".format(i), 0.01, 0.2)\n",
    "        layers.append(nn.Dropout(p))\n",
    "        in_features = out_features * heads\n",
    "\n",
    "    # Hidden linear layers, numbered 1..n_layers_linear. The previous\n",
    "    # range(1, n_layers_linear) built the same model for 0 and 1 and at most one\n",
    "    # hidden layer for 2. NOTE: trials already stored in a resumed study were\n",
    "    # built with that old behaviour.\n",
    "    n_layers_linear = trial.suggest_int(\"n_layers_linear\", 0, 2)\n",
    "    for i in range(1, n_layers_linear + 1):\n",
    "        out_features = trial.suggest_int(\"n_units_lin_l{}\".format(i), 4, 200)\n",
    "        layers.append((nn.Linear(in_features, out_features), 'x -> x'))\n",
    "        layers.append(nn.ReLU(inplace=True))\n",
    "        in_features = out_features\n",
    "\n",
    "    # Output layer: CLASSES scores for every row (node) of x.\n",
    "    layers.append((nn.Linear(in_features, CLASSES), 'x -> x'))\n",
    "    return Sequential('x, edge_index', layers)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "742f70ba-3e8b-45b2-978d-2edf82e31df2",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "df93f109-488a-472a-a72e-a0a40f9e682e",
   "metadata": {},
   "outputs": [],
   "source": [
    "class Net_large_CRF(torch.nn.Module):\n",
    "    \"\"\"GAT encoder followed by a linear-chain CRF over the non-internal nodes.\n",
    "\n",
    "    The GNN produces ``num_labels`` emission scores per node; the ``utils_GNN_CRF``\n",
    "    helpers then (per their names) drop the internal nodes and pad each graph's\n",
    "    remaining nodes to ``max_seq_len`` before the CRF is applied.\n",
    "\n",
    "    Args:\n",
    "        trial: Optuna trial used to sample the GNN architecture.\n",
    "        num_node_features: dimension of the input node features.\n",
    "        num_labels: number of tags (GNN output size and number of CRF states).\n",
    "        batch_first: forwarded to ``CRF``.\n",
    "        max_seq_len: length each graph's node sequence is padded to (was hard-coded to 62).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,trial,num_node_features, num_labels = 3, batch_first=True, max_seq_len=62):\n",
    "        super(Net_large_CRF, self).__init__()\n",
    "\n",
    "        self.num_labels = num_labels\n",
    "        self.batch_first = batch_first\n",
    "        self.max_seq_len = max_seq_len\n",
    "\n",
    "        self.GNN_attention = define_model_GATConv(trial, self.num_labels, num_node_features)\n",
    "\n",
    "        self.crf = CRF(self.num_labels, batch_first=self.batch_first)\n",
    "\n",
    "    def forward(self, data, labels=None):\n",
    "        \"\"\"Return the CRF loss when ``labels`` is given, otherwise the decoded tags.\n",
    "\n",
    "        Args:\n",
    "            data: batched graph exposing ``x``, ``edge_index``, ``batch`` and\n",
    "                ``indice_nt_tensor``.\n",
    "            labels: optional per-node labels (the training loop passes ``data.y``).\n",
    "\n",
    "        Returns:\n",
    "            With ``labels``: the negated CRF log-likelihood, used as the loss.\n",
    "            Without ``labels``: the result of ``CRF.decode`` (one tag list per graph).\n",
    "        \"\"\"\n",
    "        x, edge_index, batch, indice_nt_tensor = data.x, data.edge_index, data.batch, data.indice_nt_tensor\n",
    "\n",
    "        # Per-node emission scores, num_labels per node.\n",
    "        out = self.GNN_attention(x, edge_index)\n",
    "\n",
    "        if labels is not None:\n",
    "            sample_x_new, sample_batch_new, sample_label_new = remove_internal_nodes(indice_nt_tensor, out, batch, labels)\n",
    "            padded_labels, padding_mask = reshape_label_to_batch_padded(self.max_seq_len, sample_batch_new, sample_label_new)\n",
    "            # Last argument is presumably the per-node feature size, i.e. the GNN output\n",
    "            # size (previously a hard-coded 3 that silently assumed num_labels == 3).\n",
    "            logits = reshape_data_to_batch_padded(self.max_seq_len, sample_batch_new, sample_x_new, self.num_labels)\n",
    "            # The CRF returns a log-likelihood; negate it to obtain a loss to minimize.\n",
    "            return -self.crf(emissions=logits, tags=padded_labels, mask=padding_mask.byte())\n",
    "\n",
    "        sample_x_new, sample_batch_new = remove_internal_nodes_without_labels(indice_nt_tensor, out, batch)\n",
    "        padding_mask = reshape_mask_to_batch_padded(self.max_seq_len, sample_batch_new)\n",
    "        logits = reshape_data_to_batch_padded(self.max_seq_len, sample_batch_new, sample_x_new, self.num_labels)\n",
    "        return self.crf.decode(emissions=logits, mask=padding_mask.byte())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "6f051c31-8eb9-4def-8bd4-ef5f3f9e1c3a",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define The optuna objective function\n",
    "def objective(trial, DEVICE, EPOCHS, CLASSES, NODE_FEATURES, train_loader, dev_loader, test_loader):\n",
    "    \"\"\"Train one sampled model and return its macro token-level F1 on the dev set.\n",
    "\n",
    "    The F1 is reported to Optuna after every epoch so the study's pruner can stop\n",
    "    unpromising trials early.\n",
    "\n",
    "    Args:\n",
    "        trial: Optuna trial providing the hyper-parameters.\n",
    "        DEVICE: device the model and the batches are moved to.\n",
    "        EPOCHS: number of training epochs.\n",
    "        CLASSES: number of labels; the F1 is averaged over labels 0..CLASSES-1.\n",
    "        NODE_FEATURES: dimension of the input node features.\n",
    "        train_loader: loader of training graph batches (with ``x``, ``y``, ``indice_nt_tensor``).\n",
    "        dev_loader: loader of dev graph batches used for evaluation.\n",
    "        test_loader: unused; kept so existing callers keep working.\n",
    "\n",
    "    Returns:\n",
    "        The dev macro F1 after the last epoch (0.0 if ``EPOCHS`` is 0).\n",
    "\n",
    "    Raises:\n",
    "        optuna.exceptions.TrialPruned: when the pruner decides to stop the trial.\n",
    "    \"\"\"\n",
    "    # Choose the model.\n",
    "    model = Net_large_CRF(trial, NODE_FEATURES, num_labels=CLASSES).to(DEVICE)\n",
    "\n",
    "    # Generate the optimizer.\n",
    "    optimizer_name = trial.suggest_categorical(\"optimizer\", [\"AdamW\"])\n",
    "    lr = trial.suggest_float(\"lr\", 1e-5, 1e-3, log=True)\n",
    "    optimizer = getattr(optim, optimizer_name)(model.parameters(), lr=lr)\n",
    "\n",
    "    # Multiply the learning rate by 0.99 after every epoch.\n",
    "    scheduler = torch.optim.lr_scheduler.MultiplicativeLR(optimizer, lr_lambda=lambda epoch: 0.99)\n",
    "    max_grad_norm = trial.suggest_float(\"max_grad_norm\", 1e-1, 1e+1)\n",
    "\n",
    "    labels = list(range(CLASSES))\n",
    "    f1_token_level = 0.0\n",
    "    for epoch in range(EPOCHS):\n",
    "        _train_one_epoch(model, train_loader, optimizer, max_grad_norm, DEVICE)\n",
    "        f1_token_level = _evaluate_token_f1(model, dev_loader, DEVICE, labels)\n",
    "        trial.report(f1_token_level, epoch)\n",
    "        scheduler.step()\n",
    "        # Reporting alone never prunes: the trial has to ask the pruner.\n",
    "        if trial.should_prune():\n",
    "            raise optuna.exceptions.TrialPruned()\n",
    "\n",
    "    return f1_token_level\n",
    "\n",
    "\n",
    "def _train_one_epoch(model, loader, optimizer, max_grad_norm, device):\n",
    "    \"\"\"Run one training epoch on the model's CRF loss with gradient-norm clipping.\"\"\"\n",
    "    model.train()\n",
    "    for graph_batch in loader:\n",
    "        graph_batch = graph_batch.to(device)\n",
    "        graph_batch[\"x\"] = graph_batch[\"x\"].float()\n",
    "        optimizer.zero_grad()\n",
    "        loss = model(graph_batch, graph_batch[\"y\"])\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(parameters=model.parameters(), max_norm=max_grad_norm)\n",
    "        optimizer.step()\n",
    "\n",
    "\n",
    "def _evaluate_token_f1(model, loader, device, labels):\n",
    "    \"\"\"Return the macro F1 (over ``labels``) of the decoded tags of non-internal nodes.\"\"\"\n",
    "    model.eval()\n",
    "    y_true, y_pred = [], []\n",
    "    with torch.no_grad():\n",
    "        for graph_batch in loader:\n",
    "            graph_batch = graph_batch.to(device)\n",
    "            graph_batch[\"x\"] = graph_batch[\"x\"].float()\n",
    "            # Decoding yields one tag list per graph: flatten them.\n",
    "            y_pred.extend(tag for sequence in model(graph_batch) for tag in sequence)\n",
    "            # True labels of the nodes with indice_nt_tensor > 0 (assumed to align with\n",
    "            # y_pred); 10 is a sentinel that must not be a valid label.\n",
    "            y_masked = torch.where(graph_batch[\"indice_nt_tensor\"] > 0, graph_batch[\"y\"], 10)\n",
    "            y_true.extend(y_masked[y_masked != 10].tolist())\n",
    "    # sklearn's signature is f1_score(y_true, y_pred).\n",
    "    return f1_score(y_true, y_pred, labels=labels, average=\"macro\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "e44ec0f6-4462-4b76-85bf-5d91ceef3e35",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define the main function to call by script\n",
    "def main():\n",
    "    \"\"\"Run (or resume) the Optuna search for the GAT + CRF model and print a summary.\n",
    "\n",
    "    The study is stored in a local SQLite database and loaded if it already exists,\n",
    "    so re-running continues the same study.\n",
    "    \"\"\"\n",
    "    # ---- Configuration ----\n",
    "    batch_size = 190\n",
    "    num_trials = 200\n",
    "    # Alternative graph construction: '../../data/aurc/bert/large_depth_all_connected'\n",
    "    data_dir = '../../data/aurc/bert/large_depth_IN_connected'\n",
    "    EPOCHS = 15\n",
    "    NODE_FEATURES = 1024\n",
    "    # Forced to CPU; for a GPU use e.g.\n",
    "    # torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')\n",
    "    DEVICE = \"cpu\"\n",
    "    CLASSES = 3\n",
    "    timeout = 3600*8  # seconds, i.e. 8 hours\n",
    "    num_data_train = 3960\n",
    "    num_data_dev = 790\n",
    "    num_data_test = 1959\n",
    "\n",
    "    ## Load the Data\n",
    "    train_dataset = MyOwnDataset(num_data = num_data_train,root = data_dir+\"/Train\")\n",
    "    test_dataset = MyOwnDataset(num_data = num_data_test,root = data_dir+\"/Test\")\n",
    "    dev_dataset = MyOwnDataset(num_data = num_data_dev,root = data_dir+\"/Dev\")\n",
    "\n",
    "    train_loader = DataLoader(train_dataset, batch_size=batch_size)\n",
    "    dev_loader = DataLoader(dev_dataset, batch_size=batch_size)\n",
    "    test_loader = DataLoader(test_dataset, batch_size=batch_size)\n",
    "\n",
    "    ## Initialize the optuna research\n",
    "    study = optuna.create_study(\n",
    "        direction=\"maximize\",\n",
    "        study_name = \"Trained_BERT_GNN_CRF_IN_connected\",\n",
    "        storage=\"sqlite:///../optuna_db/Trained_BERT_large_GNN_CRF_IN_connected.db\",\n",
    "        load_if_exists=True)\n",
    "\n",
    "    study.optimize(lambda trial: objective(trial, DEVICE, EPOCHS, CLASSES, NODE_FEATURES, train_loader, dev_loader, test_loader), n_trials=num_trials, timeout=timeout)\n",
    "\n",
    "    pruned_trials = study.get_trials(deepcopy=False, states=[TrialState.PRUNED])\n",
    "    complete_trials = study.get_trials(deepcopy=False, states=[TrialState.COMPLETE])\n",
    "\n",
    "    print(\"Study statistics: \")\n",
    "    print(\"  Number of finished trials: \", len(study.trials))\n",
    "    print(\"  Number of pruned trials: \", len(pruned_trials))\n",
    "    print(\"  Number of complete trials: \", len(complete_trials))\n",
    "\n",
    "    # study.best_trial raises if no trial has completed (e.g. every trial was\n",
    "    # pruned or failed, or the timeout expired first).\n",
    "    if not complete_trials:\n",
    "        print(\"No complete trials: no best trial or parameter importance to report.\")\n",
    "        return\n",
    "\n",
    "    print(\"Best trial:\")\n",
    "    best_trial = study.best_trial\n",
    "\n",
    "    print(\"  Value: \", best_trial.value)\n",
    "\n",
    "    print(\"  Params: \")\n",
    "    for key, value in best_trial.params.items():\n",
    "        print(\"    {}: {}\".format(key, value))\n",
    "\n",
    "    print(\"Name of the study : \" + study.study_name)\n",
    "\n",
    "    # Optuna raises ValueError when evaluating importances from a single completed trial.\n",
    "    if len(complete_trials) < 2:\n",
    "        return\n",
    "    print(\"  Params Importance: \")\n",
    "    dict_params = get_param_importances(study)\n",
    "    for key, value in dict_params.items():\n",
    "        print(\"    {}: {}\".format(key, value))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b93139a9-6688-4590-9429-dedc2af9be93",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Launch (or resume) the hyper-parameter search: up to num_trials trials or the timeout set in main().\n",
    "main()"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "concept_env",
   "language": "python",
   "name": "concept_env"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
