
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import data_loader
import numpy as np
import torchvision.utils as vutils
import models
import os

from torchvision import datasets, transforms
from torch.autograd import Variable





# Command-line settings. This script evaluates saved checkpoints: the optimizer
# options (lr, momentum, wd, decreasing_lr) only feed an optimizer that is never
# stepped, and droprate, optimizer_flag and log-interval are not read anywhere
# in this file.
parser = argparse.ArgumentParser(description='PyTorch code: SDL')
parser.add_argument('--batch-size', type=int, default=64, metavar='N', help='input batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=120, metavar='N',
                    help='number of training epochs whose checkpoints are evaluated (default: 120; forced to 200 for cifar100)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N', help='how many batches to wait before logging training status')
# Only these datasets / nets are handled further down; reject anything else
# before the results file is truncated and data is loaded.
parser.add_argument('--dataset', required=True, choices=['cifar10', 'cifar100'], help='cifar10 | cifar100')
parser.add_argument('--dataroot', default='./data/', help='path to dataset')
parser.add_argument('--imageSize', type=int, default=32, help='the height / width of the input image to network')
parser.add_argument('--outf', default='./parameters/', help='root folder holding the model checkpoints')
parser.add_argument('--wd', type=float, default=0.0005, help='weight decay')
parser.add_argument('--droprate', type=float, default=0.1, help='drop rate (default: 0.1)')
parser.add_argument('--decreasing_lr', default='40,80', help='decreasing strategy')
parser.add_argument('--net_type', default='densenet', choices=['densenet', 'resnet34'], help="Type of Classification Nets")
parser.add_argument('--optimizer_flag', default='sgd', help="Type of optimizer")
parser.add_argument('--numclass', type=int, default=10, help='the # of classes')
parser.add_argument('--gpu', type=int, default=0, help='gpu index')
parser.add_argument('--sparseK', type=int, default=3, help='number of sparsity')
parser.add_argument('--beta', type=float, default=100, help='coefficient of reconstruction loss')
parser.add_argument('--n', type=int, default=7, help='half of number of columns of structured matrix B')
parser.add_argument('--sigma', type=float, default=0.3, help='std of noise')

args = parser.parse_args()
print(args)
args.cuda = not args.no_cuda and torch.cuda.is_available()
print("Random Seed: ", args.seed)
torch.manual_seed(args.seed)

# Structured-matrix hyper-parameters; B, sparseK and sn are passed to the model
# constructors. beta is not read anywhere else in this file.
sn = args.n  # a prime number for structured matrix B (B has 2*sn columns)
beta = args.beta
sparseK = args.sparseK
B = np.loadtxt('B.txt', delimiter=",")

sigma = args.sigma  # std of the Gaussian noise added to test images

# One results file per (dataset, sigma, net, seed). Create the folder first so
# the open() below does not fail on a fresh checkout.
results_dir = './results/'
if not os.path.isdir(results_dir):
    os.makedirs(results_dir)
txtfile = results_dir + args.dataset + '_Gaussian_' + str(sigma) + '_' + args.net_type + '_seed_' + str(args.seed) + '.txt'

# Truncate the file and write the column header; test() appends one row per
# evaluated epoch.
with open(txtfile, "w") as myfile:
    myfile.write('epoch' + ' ' + 'test_acc' + ' ' + 'noise_test_acc' + "\n")

if args.cuda:
    # Select the GPU first: torch.cuda.manual_seed seeds the *current* device
    # only, so seeding before set_device could leave device args.gpu unseeded.
    torch.cuda.set_device(args.gpu)
    torch.cuda.manual_seed(args.seed)

# DataLoader options; not passed to the data loaders in this file.
kwargs = {'num_workers': 1, 'pin_memory': True} if args.cuda else {}

print('load data: ', args.dataset)

if args.dataset in ('cifar10', 'cifar100'):
    # Per-channel normalization constants, written on a 0-255 scale.
    cifar_mean = (125.3/255, 123.0/255, 113.9/255)
    cifar_std = (63.0/255, 62.1/255.0, 66.7/255.0)

    transform_train = transforms.Compose([
        transforms.RandomCrop(args.imageSize, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(cifar_mean, cifar_std),
    ])

    # CIFAR-100 overrides the class count and schedule; args.epochs also
    # controls which checkpoints are evaluated below.
    if args.dataset == 'cifar100':
        args.numclass = 100
        args.decreasing_lr = '80,120,160'
        args.epochs = 200

    # Only test_loader is used by this evaluation script; train_loader is kept
    # for parity with the training setup.
    train_loader, _ = data_loader.getTargetDataSet(args.dataset, args.batch_size, transform_train, args.dataroot)
    _, test_loader = data_loader.getTargetDataSet(args.dataset, args.batch_size, transform_test, args.dataroot)
else:
    # Without this, test_loader would be undefined and fail later with NameError.
    raise ValueError('unsupported --dataset: {} (expected cifar10 or cifar100)'.format(args.dataset))




    
print('Model: ', args.net_type)
if args.net_type == 'densenet':
    # 100: presumably the DenseNet depth -- confirm against models.DenseNet3.
    model = models.DenseNet3(B, sparseK, sn, 100, int(args.numclass))
elif args.net_type == 'resnet34':
    model = models.ResNet34(B, sparseK, sn, num_c=args.numclass)
else:
    # Without this, `model` would be undefined and fail later with NameError.
    raise ValueError('unsupported --net_type: {} (expected densenet or resnet34)'.format(args.net_type))

if args.cuda:
    model.cuda()

# The optimizer, LR schedule and losses mirror the training script; none of
# them is used by the evaluation loop in this file.
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
decreasing_lr = list(map(int, args.decreasing_lr.split(',')))

lossRE = nn.MSELoss()
lossCE = nn.CrossEntropyLoss()

def test(epoch, txtfile):
    """Evaluate the global ``model`` on ``test_loader`` with and without noise.

    Each test batch is classified twice: once after adding
    ``sigma * randn`` noise and once unmodified. One line
    ``"<epoch> <clean_acc> <noisy_acc>"`` is appended to ``txtfile`` and a
    summary line is printed.

    Args:
        epoch: epoch number of the checkpoint currently loaded into ``model``.
        txtfile: path of the results file to append to.

    Raises:
        RuntimeError: if ``test_loader`` yields no samples.
    """
    model.eval()
    noisy_correct, clean_correct, total = 0, 0, 0
    with torch.no_grad():
        for data, target in test_loader:
            total += data.size(0)
            if args.cuda:
                data, target = data.cuda(), target.cuda()
            # Draw noise from the CPU generator and move it to data's device.
            # This works without CUDA (the old unconditional .cuda() crashed on
            # CPU-only runs) and keeps the same random stream as before.
            noise = torch.randn(data.shape).to(data.device)
            output, _ = model(data + sigma * noise)
            noisy_correct += output.max(1)[1].eq(target).sum().item()
            output, _ = model(data)
            clean_correct += output.max(1)[1].eq(target).sum().item()
    if total == 0:
        raise RuntimeError('test_loader yielded no samples')
    test_acc = 100. * noisy_correct / float(total)
    clean_acc = 100. * clean_correct / float(total)
    with open(txtfile, "a") as myfile:
        myfile.write(str(int(epoch)) + ' ' + str(clean_acc) + ' ' + str(test_acc) + "\n")
    print('\nEpoch: {:.0f},  Clean Test Accuracy: {}/{} ({:.2f}%), Accuracy: {}/{} ({:.2f}%)\n'.format(
        epoch, clean_correct, total, clean_acc, noisy_correct, total, test_acc))

    

# Checkpoints are read from <outf>/SDL/<net_type>/<dataset>_<net_type>_seed_<seed>/
# (the trailing '' keeps a trailing separator on args.outf).
args.outf = os.path.join(args.outf, 'SDL', args.net_type,
                         args.dataset + '_' + args.net_type + '_seed_' + str(args.seed), '')

if not os.path.isdir(args.outf):
    os.makedirs(args.outf)

# Evaluate the checkpoint of epoch 1 and of every 5th epoch.
for epoch in range(1, args.epochs + 1):
    if epoch % 5 == 0 or epoch == 1:
        TestPath = os.path.join(args.outf, str(epoch) + 'model.pth')
        # map_location='cpu' lets GPU-saved checkpoints load on CPU-only hosts;
        # load_state_dict then copies the tensors onto the model's device.
        model.load_state_dict(torch.load(TestPath, map_location='cpu'))
        test(epoch, txtfile)



