Source code for deeprobust.graph.black_box

import torch
from deeprobust.graph.defense import GCN
import pickle
import os.path as osp
from deeprobust.graph.data import Dataset
from deeprobust.graph.utils import preprocess
import os

def load_victim_model(data, model_name='gcn', device='cpu', file_path=None):
    """Load the victim model, training and saving a new one if no checkpoint exists.

    Parameters
    ----------
    data : deeprobust.graph.Dataset
        graph data
    model_name : str
        victim model name, e.g. ('gcn', 'deepwalk'). But currently it only
        supports gcn as victim model.
    device : str
        'cpu' or 'cuda'
    file_path : str, optional
        if given, the directory holding the checkpoint file
        ``<model_name>_checkpoint``; the model is loaded from (or, when the
        file is missing, trained and saved to) that location. Defaults to
        ``results/saved_models/<data.name>/``.

    Returns
    -------
    GCN
        the victim model, switched to eval mode.

    Raises
    ------
    ValueError
        if ``model_name`` is not ``'gcn'``.
    """
    # Raise explicitly instead of using `assert`: assertions are stripped
    # under `python -O`, which would silently build a GCN for any model_name.
    if model_name != 'gcn':
        raise ValueError('Currently only support gcn as victim model...')

    if file_path is None:
        file_path = 'results/saved_models/{0}/{1}_checkpoint'.format(data.name, model_name)
    else:
        file_path = osp.join(file_path, '{}_checkpoint'.format(model_name))

    if not osp.exists(file_path):
        # No checkpoint yet: train a fresh model and save it in the same
        # directory, so the next call finds it.
        return train_victim_model(data=data, model_name=model_name, device=device,
                                  file_path=osp.dirname(file_path))

    # Setup victim model. These hyperparameters mirror the ones used in
    # train_victim_model; they must agree or the saved state dict won't fit.
    victim_model = GCN(nfeat=data.features.shape[1],
                       nclass=data.labels.max().item() + 1,
                       nhid=16, dropout=0.5, weight_decay=5e-4, device=device)
    # NOTE(review): torch.load unpickles the file, so only load checkpoints
    # from a trusted location.
    victim_model.load_state_dict(torch.load(file_path, map_location=device))
    victim_model.to(device)
    victim_model.eval()
    return victim_model
def train_victim_model(data, model_name='gcn', file_path=None, device='cpu'):
    """Train the victim model (target classifier) and save its weights.

    Note that the attacker can only make black-box queries to this model.

    Parameters
    ----------
    data : deeprobust.graph.Dataset
        graph data; ``adj``, ``features``, ``labels``, ``idx_train`` and
        ``idx_val`` are read from it.
    model_name : str
        name used for the checkpoint file ``<model_name>_checkpoint``.
        A GCN is trained regardless of this value.
    file_path : str, optional
        directory the checkpoint is written to. Created if missing.
        Defaults to ``results/saved_models/<data.name>/``.
    device : str
        'cpu' or 'cuda'

    Returns
    -------
    GCN
        the trained victim model, switched to eval mode.
    """
    if file_path is None:
        file_path = 'results/saved_models/%s/' % data.name

    idx_train, idx_val = data.idx_train, data.idx_val
    adj, features, labels = preprocess(data.adj, data.features, data.labels,
                                       preprocess_adj=False)

    # Setup victim model. load_victim_model rebuilds the network with these
    # same hyperparameters before loading the saved weights.
    victim_model = GCN(nfeat=features.shape[1], nclass=labels.max().item() + 1,
                       nhid=16, dropout=0.5, weight_decay=5e-4, device=device)

    adj = adj.to(device)
    features = features.to(device)
    labels = labels.to(device)
    victim_model = victim_model.to(device)
    victim_model.fit(features, adj, labels, idx_train, idx_val)

    # Save the model. os.makedirs replaces shelling out to `mkdir -p`: it is
    # portable, cannot be abused through characters in file_path, and raises
    # on failure instead of failing silently. An empty path (bare filename
    # from osp.dirname) means the current directory, which already exists.
    if file_path:
        os.makedirs(file_path, exist_ok=True)
    torch.save(victim_model.state_dict(), osp.join(file_path, model_name + '_checkpoint'))
    victim_model.eval()
    return victim_model