{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {"colab": {"base_uri": "https://localhost:8080/"}, "executionInfo": {"elapsed": 18630, "status": "ok", "timestamp": 1693488676785, "user": {"displayName": "Shota Saito", "userId": "16137182175427338624"}, "user_tz": -540}, "id": "92oZp-LoB45N", "outputId": "15d1190b-4914-4162-c3db-1139495d6d62"},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Mounted at /content/drive\n"
     ]
    }
   ],
   "source": [
    "from google.colab import drive\n",
    "drive.mount('/content/drive')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {"code_folding": [0], "executionInfo": {"elapsed": 5142, "status": "ok", "timestamp": 1693488686817, "user": {"displayName": "Shota Saito", "userId": "16137182175427338624"}, "user_tz": -540}, "id": "DpMP2f_OBtNQ"},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import torch\n",
    "import torch.nn as nn  # nn.CrossEntropyLoss is used by the training cell\n",
    "cuda = torch.cuda.is_available()\n",
    "import numpy as np\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "import sys\n",
    "# Make the project-local modules (datautils, datasets, models, inference) importable.\n",
    "sys.path.append(\"../../semi-supervised\")\n",
    "sys.path.append(\"../../\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#### pubmed\n",
    "\n",
    "import importlib\n",
    "import datautils\n",
    "import datasets\n",
    "# Reload so edits to the local modules are picked up without restarting the kernel.\n",
    "importlib.reload(datautils)\n",
    "importlib.reload(datasets)\n",
    "from datautils import get_data_krylov\n",
    "\n",
    "labelled, unlabelled, validation, test = get_data_krylov(\n",
    "    batch_size=128,\n",
    "    labels_per_class=977,\n",
    "    validation_ratio=0.25,\n",
    "    dataset_name='pubmed',\n",
    "    norm_flag=False,\n",
    "    gamma=1,\n",
    "    dataset_path='data/',\n",
    "    n_workers=0,\n",
    ")\n",
    "\n",
    "# Weight applied to the supervised classification loss in the training loop.\n",
    "alpha = 0.01 * (len(unlabelled) + len(labelled)) / len(labelled)\n",
    "print(alpha)\n",
    "\n",
    "\n",
    "def binary_cross_entropy(r, x):\n",
    "    \"\"\"Return the binary cross-entropy between r and x, summed over the last dimension.\n",
    "\n",
    "    1e-8 is added inside both logarithms so that values of exactly 0 or 1 in r\n",
    "    do not evaluate log(0).\n",
    "    \"\"\"\n",
    "    return -torch.sum(x * torch.log(r + 1e-8) + (1 - x) * torch.log(1 - r + 1e-8), dim=-1)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {"executionInfo": {"elapsed": 4, "status": "ok", "timestamp": 1693489208792, "user": {"displayName": "Shota Saito", "userId": "16137182175427338624"}, "user_tz": -540}, "id": "R8yytVqoBtNW"},
   "outputs": [],
   "source": [
    "###model definition\n",
    "import models\n",
    "importlib.reload(models)\n",
    "from models import AuxiliaryDeepGenerativeModelDropout\n",
    "\n",
    "from itertools import cycle\n",
    "from inference import SVI, DeterministicWarmup\n",
    "\n",
    "# NOTE(review): confirm y_dim and n_feature match the tensors returned by get_data_krylov.\n",
    "y_dim = 5\n",
    "n_feature = 932\n",
    "z_dim = 50\n",
    "a_dim = 50\n",
    "h_dim = [512, 256]\n",
    "\n",
    "dropout_rate = 0.3\n",
    "\n",
    "model = AuxiliaryDeepGenerativeModelDropout([n_feature, y_dim, z_dim, a_dim, h_dim], dropout_rate=dropout_rate)\n",
    "\n",
    "optimizer = torch.optim.Adam(model.parameters(), lr=2e-4, betas=(0.9, 0.999))\n",
    "\n",
    "# We will need to use warm-up in order to achieve good performance.\n",
    "# Over 200 calls to SVI we change the autoencoder from\n",
    "# deterministic to stochastic.\n",
    "beta = DeterministicWarmup(n=200)\n",
    "\n",
    "if cuda:\n",
    "    model = model.cuda()\n",
    "elbo = SVI(model, likelihood=binary_cross_entropy, beta=beta)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def evaluate_accuracy(model, loader):\n",
    "    \"\"\"Return the fraction of samples in `loader` that `model` classifies correctly.\n",
    "\n",
    "    The model is switched to eval mode (so dropout is disabled) and the forward\n",
    "    passes run under torch.no_grad(). The predicted class is the argmax of the\n",
    "    model output over dim 1 and the target class is the argmax of y over dim 1.\n",
    "    Batches are moved to GPU 0 when the notebook-level `cuda` flag is set.\n",
    "    Returns 0.0 for an empty loader.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    n_correct, n_seen = 0.0, 0\n",
    "    with torch.no_grad():\n",
    "        for x, y in loader:\n",
    "            if cuda:\n",
    "                x, y = x.cuda(device=0), y.cuda(device=0)\n",
    "            outputs = model(x)\n",
    "            n_correct += (outputs.argmax(dim=1) == y.argmax(dim=1)).float().sum().item()\n",
    "            n_seen += x.shape[0]\n",
    "    # Weight by sample count so a short final batch does not skew the result.\n",
    "    return n_correct / n_seen if n_seen else 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from itertools import cycle\n",
    "\n",
    "N_EPOCHS = 200\n",
    "\n",
    "# NOTE(review): VATLoss, xi, eps and ip are not defined or imported anywhere in\n",
    "# this notebook; they must be in scope before this cell can run on a fresh kernel.\n",
    "# Both loss objects are loop-invariant, so they are built once instead of per batch.\n",
    "vat_loss = VATLoss(xi=xi, eps=eps, ip=ip)\n",
    "cross_entropy = nn.CrossEntropyLoss()\n",
    "\n",
    "max_accuracy_valid = 0\n",
    "max_accuracy_test = 0\n",
    "\n",
    "for epoch in range(N_EPOCHS):\n",
    "    # evaluate_accuracy() leaves the model in eval mode, so re-enable training mode every epoch.\n",
    "    model.train()\n",
    "    total_loss, n_correct, n_seen = 0.0, 0.0, 0\n",
    "\n",
    "    # The labelled loader is cycled so every unlabelled batch is paired with a labelled one.\n",
    "    for (x, y), (u, _) in zip(cycle(labelled), unlabelled):\n",
    "        if cuda:\n",
    "            # They need to be on the same device and be synchronized.\n",
    "            x, y, u = x.cuda(device=0), y.cuda(device=0), u.cuda(device=0)\n",
    "\n",
    "        # Clear gradients from the previous step; otherwise they accumulate across iterations.\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        lds = vat_loss(model, u)\n",
    "        output = model(x)\n",
    "        classification_loss = cross_entropy(output, y)\n",
    "\n",
    "        J_alpha = lds + alpha * classification_loss\n",
    "        J_alpha.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += J_alpha.item()\n",
    "        n_correct += (output.argmax(dim=1) == y.argmax(dim=1)).float().sum().item()\n",
    "        n_seen += x.shape[0]\n",
    "\n",
    "    # Guard against an empty unlabelled loader (no batches processed).\n",
    "    n_seen = max(n_seen, 1)\n",
    "    train_accuracy = n_correct / n_seen\n",
    "\n",
    "    print(\"Epoch: {}\".format(epoch))\n",
    "    print(\"[Train]\\t\\t J_a: {:.2f}, accuracy: {:.4f}\".format(total_loss / n_seen, train_accuracy))\n",
    "\n",
    "    accuracy_valid = evaluate_accuracy(model, validation)\n",
    "    print(\"[Validation]\\t accuracy: {:.4f}\".format(accuracy_valid))\n",
    "\n",
    "    accuracy_test = evaluate_accuracy(model, test)\n",
    "\n",
    "    # Track the test accuracy at the epoch with the best validation accuracy;\n",
    "    # when validation ties with the best so far, keep the higher test accuracy.\n",
    "    if accuracy_valid > max_accuracy_valid:\n",
    "        max_accuracy_valid = accuracy_valid\n",
    "        max_accuracy_test = accuracy_test\n",
    "    elif accuracy_valid == max_accuracy_valid:\n",
    "        max_accuracy_test = max(max_accuracy_test, accuracy_test)\n",
    "\n",
    "    print(\"[Test]\\t accuracy_current: {:.4f}, accuracy_max: {:.4f}\".format(accuracy_test, max_accuracy_test))\n"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {"gpuType": "V100", "machine_shape": "hm", "provenance": []},
  "kernelspec": {"display_name": "Python 3", "name": "python3"},
  "language_info": {"codemirror_mode": {"name": "ipython", "version": 3}, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.11"}
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
