from utils import *
from funcs import train_ann
import argparse
from PreProcess import GetCifar10, GetCifar100
from Models.ResNet import *
from Models.VGG import *


# Command-line entry point: parse the options, load the requested dataset,
# build the requested architecture, rewrite its pooling/activation layers,
# train it with `train_ann`, and save the resulting model to disk.
parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='CIFAR10', help='Dataset name')
parser.add_argument('--datadir', type=str, default='../datasets', help='Dataset location')
parser.add_argument('--arch', type=str, default='vgg16', help='Architecture')
parser.add_argument('--savename', type=str, default='MyModel', help='Model saving name')
parser.add_argument('--device', type=str, default='cuda:0', help='Device')
parser.add_argument('--epochs', type=int, default=300, help='Number of epochs')
parser.add_argument('--batchsize', type=int, default=50, help='Batch size')
parser.add_argument('--L', type=int, default=4, help='Quantization time-step')
parser.add_argument('--lr', type=float, default=0.1, help='Learning rate')
parser.add_argument('--wd', type=float, default=5e-4, help='Weight decay')

args = parser.parse_args()

# Dataset and architecture names are matched case-insensitively.
dataset_name = args.dataset.lower()
arch_name = args.arch.lower()

# Resolve both choices BEFORE touching the data so that a typo in either
# option is reported immediately. `parser.error` prints the message together
# with the usage string and exits with status 2, so execution never continues
# with `cls` or `model` left undefined.
if dataset_name == 'cifar10':
    load_data = GetCifar10
    cls = 10
elif dataset_name == 'cifar100':
    load_data = GetCifar100
    cls = 100
else:
    parser.error('unable to find dataset ' + args.dataset)

if arch_name == 'resnet20':
    build_model = resnet20
elif arch_name == 'resnet18':
    build_model = resnet18
elif arch_name == 'vgg16':
    build_model = vgg16
else:
    parser.error('unable to find model ' + args.arch)

# get data
train, test = load_data(args.datadir, args.batchsize)

# get model (`cls` is the number of output classes for the chosen dataset)
model = build_model(num_classes=cls)

# use avgpooling instead of maxpooling
model = replace_maxpool2d_by_avgpool2d(model)
# Replace the activations via `replace_activation_by_floor`, passing --L.
# NOTE(review): the original comment called this a "trainable clipping
# layer"; the helper is defined outside this file -- confirm its semantics.
model = replace_activation_by_floor(model, args.L)

# training
model = train_ann(train, test, model, epochs=args.epochs, lr=args.lr, wd=args.wd, device=args.device)

# saving
# NOTE(review): this pickles the whole model object rather than a state_dict,
# so loading it presumably requires the same model class definitions to be
# importable. Kept as-is because downstream loaders may expect the full
# object -- confirm before switching to `model.state_dict()`.
torch.save(model, args.savename + '.pkl')
