{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":24216,"status":"ok","timestamp":1694005559248,"user":{"displayName":"Shota Saito","userId":"16137182175427338624"},"user_tz":-540},"id":"92oZp-LoB45N","outputId":"a9dc5de9-4933-4519-cbea-2a2d5f3a8472"},"outputs":[{"name":"stdout","output_type":"stream","text":["Mounted at /content/drive\n"]}],"source":["from google.colab import drive\n","drive.mount('/content/drive')"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":1667,"status":"ok","timestamp":1694005562539,"user":{"displayName":"Shota Saito","userId":"16137182175427338624"},"user_tz":-540},"id":"-XX6SYZsB6mj","outputId":"d292d773-1e14-4536-d2af-7779e39439da"},"outputs":[],"source":["import os\n","print(\"current working directiry is [\" + os.getcwd() + \"]\")\n","_colab_dir = \"/content/drive/MyDrive/Colab Notebooks/ResTran/examples/notebooks\"\n","os.chdir(_colab_dir)\n","print(\"current working directiry is [\" + os.getcwd() + \"]\")"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"DpMP2f_OBtNQ"},"outputs":[],"source":["# Imports\n","import torch\n","cuda = torch.cuda.is_available()\n","import numpy as np\n","import matplotlib.pyplot as plt\n","%matplotlib inline\n","import sys\n","sys.path.append(\"../../semi-supervised\")\n","sys.path.append(\"../../\")\n"]},{"cell_type":"code","execution_count":null,"metadata":{"id":"QBDVgB5qmNc2"},"outputs":[],"source":["import torch.nn as nn\n","import torch.nn.functional as F\n","\n","class NN(nn.Module):\n","    def __init__(self, dims, dropout_rate):\n","\n","        super(NN, self).__init__()\n","\n","        [x_dim, y_dim, h_dim] = dims\n","        self.dropout_rate = dropout_rate\n","        neurons = [x_dim, *h_dim, y_dim]\n","        linear_layers = [nn.Linear(neurons[i-1], neurons[i]) for i in range(1, len(neurons))]\n","\n","        self.hidden = nn.ModuleList(linear_layers)\n","\n","    def forward(self, x):\n","        for layer in self.hidden:\n","            x = F.relu(layer(x))\n","            x = F.dropout(x, p=self.dropout_rate, training=self.training)\n","        return x"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":462,"status":"ok","timestamp":1694010980779,"user":{"displayName":"Shota Saito","userId":"16137182175427338624"},"user_tz":-540},"id":"ilQ0gO9QBtNV","outputId":"8501cf1c-9ec1-43ca-d203-a4a07fab6e53"},"outputs":[],"source":["#### data\n","\n","import importlib\n","import datautils\n","import datasets\n","importlib.reload(datautils)\n","importlib.reload(datasets)\n","from datautils import get_data_krylov\n","\n","# Only use 10 labelled examples per class\n","# The rest of the data is unlabelled.\n","labelled, unlabelled, validation, test = get_data_krylov(batch_size=128,\n","                                                                    labels_per_class=977,\n","                                                                    validation_ratio=0.25,\n","                                                                    dataset_name='pubmed',\n","                                                                    norm_flag=False,\n","                                                                    gamma=1,\n","                                                                    b = 0,\n","                                                                    #dataset_path='/content/drive/My Drive/Colab Notebooks/ssldeep/semi-supervised-pytorch-master/examples/notebooks/data')\n","                                                                    dataset_path='data/',\n","                                                                    n_workers = 0)\n","\n","alpha = 0.1 * (len(unlabelled) + len(labelled)) / len(labelled)\n","print(alpha)\n","def binary_cross_entropy(r, x):\n","    return -torch.sum(x * torch.log(r + 1e-8) + (1 - x) * torch.log(1 - r + 1e-8), dim=-1)\n"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/"},"executionInfo":{"elapsed":10,"status":"ok","timestamp":1694011076530,"user":{"displayName":"Shota Saito","userId":"16137182175427338624"},"user_tz":-540},"id":"R8yytVqoBtNW","outputId":"81ace778-b7a4-4783-8db2-7bcf4f2cc8a5"},"outputs":[{"name":"stdout","output_type":"stream","text":["NN(\n","  (hidden): ModuleList(\n","    (0): Linear(in_features=500, out_features=1024, bias=True)\n","    (1): Linear(in_features=1024, out_features=512, bias=True)\n","    (2): Linear(in_features=512, out_features=3, bias=True)\n","  )\n",")\n"]}],"source":["###model definition\n","import models\n","importlib.reload(models)\n","from models import VATLoss\n","\n","\n","y_dim = 3\n","#h_dim = [2048, 1024]\n","h_dim = [1024, 512]\n","#h_dim = [512, 256]\n","n_feature = 500\n","\n","dropout_rate = 0.03\n","lr = 1e-4\n","momentum = 0.4\n","xi = 0.053\n","eps = 0.049\n","ip = 1\n","\n","alpha = 1\n","#alpha = 1 * (len(unlabelled) + len(labelled)) / len(labelled)\n","\n","model = NN([n_feature, y_dim, h_dim],dropout_rate=dropout_rate)\n","print(model)\n","\n","#optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)\n","optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))\n","if cuda: model = model.cuda()"]},{"cell_type":"code","execution_count":null,"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":1000},"executionInfo":{"elapsed":206292,"status":"error","timestamp":1694011287677,"user":{"displayName":"Shota Saito","userId":"16137182175427338624"},"user_tz":-540},"id":"7_unR5VqBtNX","outputId":"21c73766-39cb-4d11-f2df-996080efe623"},"outputs":[],"source":[]}],"metadata":{"accelerator":"GPU","colab":{"machine_shape":"hm","provenance":[]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.11"}},"nbformat":4,"nbformat_minor":0}
