Source code for deeprobust.image.utils

import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
import urllib.request

import os

def create_train_dataset(batch_size = 128, root = '../data'):
    """
    Build a shuffled DataLoader over the MNIST training split.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch, forwarded to the DataLoader.
    root : str
        Dataset root directory handed to ``torchvision.datasets.MNIST``
        (called with ``download=True``).

    Returns
    -------
    torch.utils.data.DataLoader
        Loader over the training split, created with ``shuffle=True`` and
        ``num_workers=2``.
    """
    # The only preprocessing step is the conversion to a tensor.
    to_tensor_only = transforms.Compose([transforms.ToTensor()])
    mnist_train = torchvision.datasets.MNIST(
        root=root,
        train=True,
        download=True,
        transform=to_tensor_only,
    )
    return torch.utils.data.DataLoader(
        mnist_train,
        batch_size=batch_size,
        shuffle=True,
        num_workers=2,
    )
def create_test_dataset(batch_size = 128, root = '../data'):
    """
    Build a non-shuffled DataLoader over the MNIST test split.

    Parameters
    ----------
    batch_size : int
        Number of samples per batch, forwarded to the DataLoader.
    root : str
        Dataset root directory handed to ``torchvision.datasets.MNIST``
        (called with ``download=True``).

    Returns
    -------
    torch.utils.data.DataLoader
        Loader over the test split, created with ``shuffle=False`` and
        ``num_workers=2``.
    """
    transform_test = transforms.Compose([
        transforms.ToTensor(),
    ])
    testset = torchvision.datasets.MNIST(root=root, train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2)
    return testloader


def download_model(url, file):
    """
    Download ``url`` to the local path ``file``.

    Parameters
    ----------
    url : str
        Location to fetch.
    file : str
        Destination path passed to ``urllib.request.urlretrieve``.

    Raises
    ------
    Exception
        If the download fails for any ordinary reason. The underlying error
        is attached as ``__cause__`` so the real failure stays visible.
    """
    print('Dowloading from {} to {}'.format(url, file))
    try:
        urllib.request.urlretrieve(url, file)
    except Exception as err:
        # Catch ``Exception`` rather than using a bare ``except`` so that
        # KeyboardInterrupt / SystemExit are not turned into a download error,
        # and chain the original error instead of discarding it.
        raise Exception("Download failed! Make sure you have stable Internet connection and enter the right name") from err


def save_checkpoint(now_epoch, net, optimizer, lr_scheduler, file_name):
    """
    Save the training state to ``file_name`` with ``torch.save``.

    The stored dictionary has the keys ``'epoch'``, ``'state_dict'``,
    ``'optimizer_state_dict'`` and ``'lr_scheduler_state_dict'``.

    Parameters
    ----------
    now_epoch : int
        Epoch number stored under ``'epoch'``.
    net, optimizer, lr_scheduler
        Objects whose ``state_dict()`` results are stored.
    file_name : str
        Destination path. An existing file is overwritten after a notice is
        printed.
    """
    checkpoint = {'epoch': now_epoch,
                  'state_dict': net.state_dict(),
                  'optimizer_state_dict': optimizer.state_dict(),
                  'lr_scheduler_state_dict': lr_scheduler.state_dict()}
    if os.path.exists(file_name):
        print('Overwriting {}'.format(file_name))
    torch.save(checkpoint, file_name)


def load_checkpoint(file_name, net = None, optimizer = None, lr_scheduler = None):
    """
    Restore state from a checkpoint written by ``save_checkpoint``.

    Parameters
    ----------
    file_name : str
        Path of the checkpoint file.
    net, optimizer, lr_scheduler : optional
        Each object that is not ``None`` receives its saved state through
        ``load_state_dict``.

    Returns
    -------
    The value stored under ``'epoch'``, or ``None`` when ``file_name`` is not
    an existing file (a message is printed in that case).
    """
    if not os.path.isfile(file_name):
        print("=> no checkpoint found at '{}'".format(file_name))
        return None

    print("=> loading checkpoint '{}'".format(file_name))
    # NOTE(security): torch.load unpickles the file; only load checkpoints
    # that come from a trusted source.
    check_point = torch.load(file_name)
    if net is not None:
        print('Loading network state dict')
        net.load_state_dict(check_point['state_dict'])
    if optimizer is not None:
        print('Loading optimizer state dict')
        optimizer.load_state_dict(check_point['optimizer_state_dict'])
    if lr_scheduler is not None:
        print('Loading lr_scheduler state dict')
        lr_scheduler.load_state_dict(check_point['lr_scheduler_state_dict'])

    return check_point['epoch']


from texttable import Texttable
def tab_printer(args):
    """
    Print the attributes of ``args`` as a two-column text table.

    Parameters
    ----------
    args : object
        Any object accepted by ``vars()`` (for example an argparse
        namespace). Attribute names are listed in sorted order; underscores
        are replaced by spaces and the result is capitalized.
    """
    params = vars(args)

    # Header row first, then one row per parameter in sorted name order.
    rows = [["Parameter", "Value"]]
    for name in sorted(params.keys()):
        label = name.replace("_", " ").capitalize()
        rows.append([label, params[name]])

    table = Texttable()
    table.add_rows(rows)
    print(table.draw())
def onehot_like(a, index, value=1):
    """Return a zero array shaped like ``a`` with one entry set to ``value``.

    Parameters
    ----------
    a : array_like
        Template array; the result is created with ``np.zeros_like(a)``.
    index : int
        Position that receives ``value``.
    value : single value compatible with a.dtype
        Value written at ``index`` (defaults to 1).

    Returns
    -------
    `numpy.ndarray`
        New array that is zero everywhere except at ``index``. ``a`` itself
        is left untouched.
    """
    # TODO (kept from the original): revise the note in this docstring.
    onehot = np.zeros_like(a)
    onehot[index] = value
    return onehot
def reduce_sum(x, keepdim=True):
    """
    Sum ``x`` over every dimension except dimension 0.

    Parameters
    ----------
    x : tensor
        Object exposing ``dim()`` and ``sum(dim, keepdim=...)``.
    keepdim : bool
        Forwarded to each ``sum`` call.

    Returns
    -------
    The reduced tensor. A 1-D input is returned unchanged.
    """
    # Walk from the last dimension down to dimension 1 so the remaining
    # indices stay valid even when keepdim=False removes dimensions.
    axis = x.dim() - 1
    while axis >= 1:
        x = x.sum(axis, keepdim=keepdim)
        axis -= 1
    return x
def arctanh(x, eps=1e-6):
    """
    Calculate arctanh(x) after shrinking ``x`` by a factor of ``1 - eps``.

    Parameters
    ----------
    x : array_like or float
        Input values.
    eps : float
        Shrink factor offset; ``x`` is multiplied by ``1 - eps`` before the
        logarithm, so inputs of exactly +/-1 give a finite result.

    Returns
    -------
    ``0.5 * log((1 + x') / (1 - x'))`` with ``x' = x * (1 - eps)``.
    """
    # Build a scaled copy instead of ``x *= ...``: the in-place form modified
    # the caller's array and failed on integer arrays.
    scaled = x * (1. - eps)
    return (np.log((1 + scaled) / (1 - scaled))) * 0.5
def l2r_dist(x, y, keepdim=True, eps=1e-8):
    """
    Square root of the summed squared difference between ``x`` and ``y``.

    The sum is taken by ``reduce_sum``; ``eps`` is added before the square
    root to prevent an infinite gradient at 0.
    """
    squared_sum = reduce_sum((x - y) ** 2, keepdim=keepdim)
    squared_sum.add_(eps)  # in-place, same as ``+=``
    return squared_sum.sqrt()


def l2_dist(x, y, keepdim=True):
    """Summed squared difference between ``x`` and ``y`` (no square root)."""
    diff = x - y
    return reduce_sum(diff ** 2, keepdim=keepdim)


def l1_dist(x, y, keepdim=True):
    """Summed absolute difference between ``x`` and ``y``."""
    abs_diff = torch.abs(x - y)
    return reduce_sum(abs_diff, keepdim=keepdim)


def l2_norm(x, keepdim=True):
    """Square root of the summed squares of ``x``."""
    squared_sum = reduce_sum(x * x, keepdim=keepdim)
    return squared_sum.sqrt()


def l1_norm(x, keepdim=True):
    """Summed absolute values of ``x``."""
    magnitudes = x.abs()
    return reduce_sum(magnitudes, keepdim=keepdim)
def adjust_learning_rate(optimizer, epoch, learning_rate):
    """
    Apply a step-wise learning-rate decay to every parameter group.

    The rate is ``learning_rate`` scaled by 0.1 from epoch 55, by 0.01 from
    epoch 75 and by 0.001 from epoch 90; before epoch 55 it is unchanged.

    Parameters
    ----------
    optimizer : object with a ``param_groups`` list of dicts
        Each group's ``'lr'`` entry is overwritten.
    epoch : int
        Current epoch.
    learning_rate : float
        Base learning rate.

    Returns
    -------
    The same ``optimizer`` object.
    """
    # Checked from the latest milestone down so the first match wins.
    schedule = ((90, 0.001), (75, 0.01), (55, 0.1))

    new_lr = learning_rate
    for start_epoch, factor in schedule:
        if epoch >= start_epoch:
            new_lr = learning_rate * factor
            break

    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return optimizer
def progress_bar(current, total, msg=None):
    """
    Draw a one-line text progress bar with timing information on stdout.

    Parameters
    ----------
    current : int
        Zero-based index of the current step; 0 resets the bar's start time.
    total : int
        Total number of steps.
    msg : str, optional
        Extra text appended after the timing fields.

    Notes
    -----
    NOTE(review): this function uses ``time``, ``sys``, ``TOTAL_BAR_LENGTH``,
    ``term_width``, ``format_time`` and the globals ``last_time`` /
    ``begin_time``. None of them are defined in the visible part of this
    file - confirm they exist at module level.
    """
    global last_time, begin_time
    if current == 0:
        begin_time = time.time()  # Reset for new bar.

    done_len = int(TOTAL_BAR_LENGTH*current/total)
    todo_len = int(TOTAL_BAR_LENGTH - done_len) - 1

    # Bar body: '=' for finished steps, '>' as the head, '.' for the rest.
    sys.stdout.write(' [' + '=' * done_len + '>' + '.' * todo_len + ']')

    now = time.time()
    step_time = now - last_time
    last_time = now
    tot_time = now - begin_time

    parts = [
        ' Step: %s' % format_time(step_time),
        ' | Tot: %s' % format_time(tot_time),
    ]
    if msg:
        parts.append(' | ' + msg)
    msg = ''.join(parts)
    sys.stdout.write(msg)

    # Pad with spaces up to the terminal width.
    sys.stdout.write(' ' * (term_width-int(TOTAL_BAR_LENGTH)-len(msg)-3))

    # Go back to the center of the bar.
    sys.stdout.write('\b' * (term_width-int(TOTAL_BAR_LENGTH/2)+2))
    sys.stdout.write(' %d/%d ' % (current+1, total))

    # Stay on the same line until the final step, then move to a new line.
    sys.stdout.write('\r' if current < total-1 else '\n')
    sys.stdout.flush()