
from __future__ import print_function
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import data_loader
import numpy as np
import torchvision.utils as vutils
import models
import os

from torchvision import datasets, transforms
from torch.autograd import Variable





# Training settings
parser = argparse.ArgumentParser(description='PyTorch code: SDL')
parser.add_argument('--batch-size', type=int, default=64, metavar='N', help='input batch size for training (default: 64)')
parser.add_argument('--epochs', type=int, default=120, metavar='N', help='number of epochs to train (default: 120)')
parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='learning rate (default: 0.1)')
parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)')
parser.add_argument('--no-cuda', action='store_true', default=False, help='disables CUDA training')
parser.add_argument('--seed', type=int, default=1, metavar='S', help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=100, metavar='N', help='how many batches to wait before logging training status')
# Only these datasets have transforms defined below; reject anything else up front.
parser.add_argument('--dataset', required=True, choices=['cifar10', 'cifar100'], help='cifar10 | cifar100')
parser.add_argument('--dataroot', default='./data/', help='path to dataset')
parser.add_argument('--imageSize', type=int, default=32, help='the height / width of the input image to network')
parser.add_argument('--outf', default='./parameters/', help='folder to output images and model checkpoints')
parser.add_argument('--wd', type=float, default=0.0005, help='weight decay')
# Multiplicative factor applied to the learning rate at each epoch listed in --decreasing_lr.
parser.add_argument('--droprate', type=float, default=0.1, help='learning rate decay factor (default: 0.1)')
# Comma-separated epoch numbers; overridden to '80,120,160' for cifar100.
parser.add_argument('--decreasing_lr', default='40,80', help='decreasing strategy')
# Only these networks have a constructor branch below; reject anything else up front.
parser.add_argument('--net_type', default='densenet', choices=['densenet', 'resnet34'], help="Type of Classification Nets")
# NOTE(review): currently unused -- the script always builds an SGD optimizer.
parser.add_argument('--optimizer_flag', default='sgd', help="Type of optimizer")
# Overridden to 100 for cifar100.
parser.add_argument('--numclass', type=int, default=10, help='the # of classes')
parser.add_argument('--gpu', type=int, default=0, help='gpu index')
parser.add_argument('--sparseK', type=int, default=3, help='number of sparsity')
parser.add_argument('--beta', type=float, default=100, help='coefficient of reconstruction loss')
parser.add_argument('--n', type=int, default=7, help='half of number of columns of structured matrix B')


args = parser.parse_args()
print(args)
args.cuda = not args.no_cuda and torch.cuda.is_available()
print("Random Seed: ", args.seed)
torch.manual_seed(args.seed)

# Half the number of columns of the structured matrix B (the original author
# notes this should be a prime number; B has 2*sn columns).
sn = args.n

beta = args.beta  # weight of the reconstruction (MSE) term in the training loss
sparseK = args.sparseK
# Structured matrix B, read from a comma-separated text file relative to the
# current working directory.
B = np.loadtxt('B.txt', delimiter=",")


# Extra DataLoader options; populated only when running on CUDA.
# NOTE(review): this dict is not passed to data_loader.getTargetDataSet below.
kwargs = {}
if args.cuda:
    # Seed the CUDA RNG too and select the requested device index.
    torch.cuda.manual_seed(args.seed)
    torch.cuda.set_device(args.gpu)
    kwargs = {'num_workers': 1, 'pin_memory': True}

print('load data: ', args.dataset)

if args.dataset in ('cifar10', 'cifar100'):
    # Per-channel mean/std on the [0, 1] scale produced by ToTensor
    # (hard-coded; presumably CIFAR training-set statistics -- confirm).
    normalize = transforms.Normalize((125.3/255, 123.0/255, 113.9/255), (63.0/255, 62.1/255.0, 66.7/255.0))

    # Training images get random crop + horizontal flip augmentation.
    transform_train = transforms.Compose([
        transforms.RandomCrop(args.imageSize, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Test images are only converted and normalized.
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    if args.dataset == 'cifar100':
        # NOTE(review): these unconditionally override any --numclass,
        # --decreasing_lr and --epochs given on the command line.
        args.numclass = 100
        args.decreasing_lr = '80,120,160'
        args.epochs = 200

    # Called twice so the train loader uses the augmenting transform and the
    # test loader the plain one; presumably getTargetDataSet returns
    # (train_loader, test_loader) -- confirm against data_loader.
    train_loader, _ = data_loader.getTargetDataSet(args.dataset, args.batch_size, transform_train, args.dataroot)
    _, test_loader = data_loader.getTargetDataSet(args.dataset, args.batch_size, transform_test, args.dataroot)
else:
    # Fail fast instead of crashing later on an undefined train_loader.
    raise ValueError('unsupported dataset: {}'.format(args.dataset))




    
print('Model: ', args.net_type)
if args.net_type == 'densenet':
    # NOTE(review): 100 is presumably the DenseNet depth -- confirm against models.DenseNet3.
    model = models.DenseNet3(B, sparseK, sn, 100, int(args.numclass))
elif args.net_type == 'resnet34':
    model = models.ResNet34(B, sparseK, sn, num_c=args.numclass)
else:
    # Fail fast instead of crashing later on an undefined `model`.
    raise ValueError('unsupported net_type: {}'.format(args.net_type))

if args.cuda:
    model.cuda()

# --optimizer_flag is not consulted: training always uses SGD with momentum
# and weight decay.
optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=args.momentum, weight_decay=args.wd)
# Epoch numbers at which the learning rate is decayed, e.g. '40,80' -> [40, 80].
decreasing_lr = [int(epoch_str) for epoch_str in args.decreasing_lr.split(',')]

lossRE = nn.MSELoss()           # reconstruction loss (model reconstruction vs. input batch)
lossCE = nn.CrossEntropyLoss()  # classification loss

def train(epoch):
    """Run one training epoch over the module-level ``train_loader``.

    The objective is cross-entropy on the model's class output plus ``beta``
    times the MSE between the model's reconstruction and its input batch.
    Progress is printed every ``args.log_interval`` batches.

    Args:
        epoch: epoch number; used only in the progress log line.
    """
    model.train()
    # Assumes train_loader is a torch DataLoader (has .dataset and __len__).
    num_samples = len(train_loader.dataset)
    num_batches = len(train_loader)
    for batch_idx, (data, target) in enumerate(train_loader):
        if args.cuda:
            data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        # The model returns both class scores and a reconstruction of its input.
        output, ReX = model(data)

        Closs = lossCE(output, target)
        Reloss = lossRE(ReX, data)
        loss = Closs + beta * Reloss

        loss.backward()
        optimizer.step()
        if batch_idx % args.log_interval == 0:
            # Report samples processed out of the full dataset, and the
            # percentage of batches completed in this epoch.
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tReLoss: {:.12f}\tCloss:  {:.6f}'.format(
                epoch, batch_idx * len(data), num_samples,
                100. * batch_idx / num_batches, loss.item(), Reloss.item(), Closs.item()))


    

# Checkpoints go to <outf>/SDL/<net_type>/<dataset>_<net_type>_seed_<seed>/
# (the trailing '' keeps a trailing path separator on args.outf).
args.outf = os.path.join(args.outf, 'SDL', args.net_type,
                         args.dataset + '_' + args.net_type + '_seed_' + str(args.seed), '')

if not os.path.isdir(args.outf):
    os.makedirs(args.outf)


for epoch in range(1, args.epochs + 1):
    train(epoch)
    # Step decay: once a listed epoch finishes, scale every parameter group's lr.
    if epoch in decreasing_lr:
        for param_group in optimizer.param_groups:
            param_group['lr'] *= args.droprate
    # Checkpoint after the first epoch, every 5th epoch, and always the final one.
    if epoch % 5 == 0 or epoch == 1 or epoch == args.epochs:
        torch.save(model.state_dict(), os.path.join(args.outf, str(epoch) + 'model.pth'))



