# Pretrain an image classifier on CIFAR-100.
# ResNet-18 is the active model; to pretrain VGG instead, see the commented-out model line below.
import argparse
import torch
import torch.nn as nn
import logging
import os
import sys
sys.path.append('../')
sys.path.append('../../')
from util.accumulator import *
from vgg import  *
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import numpy as np
from tensorboardX import SummaryWriter
from torchsummary import summary
from resnet18.model import Net
def _build_parser():
    """Create the command-line parser for the pretraining script.

    Options are registered from a table so that every flag, default and help
    string is visible in one place.
    """
    cli = argparse.ArgumentParser()
    option_table = (
        # Free-form description of the run; only echoed in the log.
        ('--config', dict(default='cifar100 vgg19/resnet18 pretrain ', type=str)),
        # Value exported as CUDA_VISIBLE_DEVICES.
        ('--gpu', dict(default='0', type=str)),
        ('--seed', dict(default=42, type=int)),
        ('--num_epochs', dict(default=200, type=int)),
        # Evaluate on the test split every `eval_freq` epochs.
        ('--eval_freq', dict(default=1, type=int)),
        ('--batch_size', dict(type=int, default=32, help='batch size')),
        ('--save_dir', dict(type=str, default='./pretrain', help='savedir name')),
        # Run identifier: names the log file, TensorBoard run and checkpoint.
        ('--id', dict(default="resnet", type=str)),
        ('--lr', dict(default=1e-4, type=float)),
        # Switch from the default Adam optimizer to Nesterov SGD.
        ('--sgd', dict(default=False, action='store_true')),
    )
    for flag, spec in option_table:
        cli.add_argument(flag, **spec)
    return cli


parser = _build_parser()
# Unknown arguments are tolerated (and discarded) rather than raising an error.
args, _ = parser.parse_known_args()

# Run name: reused for the TensorBoard run directory, the log file and the
# checkpoint file name.
logname = args.id
writer = SummaryWriter(os.path.join('./runs/pretrain', args.id))

# The output directory must exist before the file handler opens its log file.
os.makedirs(args.save_dir, exist_ok=True)

# Root logger: INFO and above go to stdout and to <save_dir>/<id>.log.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
logger = logging.getLogger()
fh = logging.FileHandler(os.path.join(args.save_dir, logname+".log"))
fh.setFormatter(logging.Formatter(log_format))
logger.addHandler(fh)

# Record the full configuration at the top of the log.
logger.info(str(args) + '\n')


# Restrict which GPUs are visible; set here, before the model is moved to CUDA.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

# Seed the CPU and CUDA random number generators when a seed is given.
if args.seed is not None:
    for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        _seed_fn(args.seed)

# Let cuDNN benchmark and pick convolution algorithms for the observed shapes.
torch.backends.cudnn.benchmark = True



# CIFAR-100 data pipeline.
_CIFAR100_ROOT = '../dataset/cifar100'

# NOTE(review): these mean/std values look like the ones commonly quoted for
# CIFAR-10 - confirm they are intended here before changing them.
_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

# Training images are augmented with a padded random crop and a random flip.
_train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _normalize,
])
# Test images are only converted to tensors and normalised.
_test_transform = transforms.Compose([
    transforms.ToTensor(),
    _normalize,
])

# Download is requested for the training split only; the test split is read
# from the same root directory.
train_dataset = datasets.CIFAR100(root=_CIFAR100_ROOT, train=True,
                                  transform=_train_transform, download=True)
test_dataset = datasets.CIFAR100(root=_CIFAR100_ROOT, train=False,
                                 transform=_test_transform)

# Only the training loader shuffles.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                           batch_size=args.batch_size,
                                           shuffle=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset,
                                          batch_size=args.batch_size,
                                          shuffle=False)


# Model: `Net` from resnet18.model is active. To pretrain VGG instead, use:
#net = VGG_CIFAR(num_classes=100)
net = nn.DataParallel(Net(num_classes=100))
net.cuda()

# Optimizer: Adam by default, Nesterov-momentum SGD when --sgd is passed.
# Both use the same learning rate and weight decay.
if args.sgd:
    optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9,
                                weight_decay=1e-4, nesterov=True)
else:
    optimizer = torch.optim.Adam(net.parameters(), lr=args.lr, weight_decay=1e-4)

# Multiply the learning rate by 0.1 at 50% and again at 80% of the epochs.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[int(0.5 * args.num_epochs), int(0.8 * args.num_epochs)],
    gamma=0.1)

criterion = nn.CrossEntropyLoss().cuda()
# Tracks the 'cent' (cross-entropy) and 'acc' metrics logged by train()/test().
accm = Accumulator('cent', 'acc')

# Not referenced anywhere else in the visible part of this file.
num=0
pnum=0



def train():
    """Run the full training schedule, then save the final weights.

    For every epoch from 1 to ``args.num_epochs`` this:

    * logs the learning rate of each optimizer parameter group,
    * performs one optimisation pass over ``train_loader``,
    * logs the accumulated ``'cent'`` / ``'acc'`` metrics and writes them to
      TensorBoard under ``training/loss`` and ``training/accuracy``,
    * steps the learning-rate scheduler, and
    * calls ``test`` whenever the epoch is a multiple of ``args.eval_freq``.

    After the last epoch the model ``state_dict`` is saved to
    ``<save_dir>/<id>model.tar`` and the TensorBoard writer is closed.
    """
    for epoch in range(1, args.num_epochs+1):
        accm.reset()

        # One formatted learning rate per parameter group, appended to the header.
        lr_text = ''.join(' {:.3e}'.format(group['lr'])
                          for group in optimizer.param_groups)
        logger.info('epoch {} starts with lr'.format(epoch) + lr_text)

        # test() leaves the network in eval mode, so re-enable train mode here.
        net.train()

        for inputs, targets in train_loader:
            inputs, targets = inputs.cuda(), targets.cuda()
            optimizer.zero_grad()
            logits = net(inputs)
            batch_loss = criterion(logits, targets)
            batch_loss.backward()
            optimizer.step()
            accm.update([batch_loss.item(), accuracy(logits, targets)])

        # Report training metrics before test() resets the shared accumulator.
        logger.info(accm.info(header='train', epoch=epoch))
        for tag, metric in (("training/loss", 'cent'),
                            ("training/accuracy", 'acc')):
            writer.add_scalar(tag, accm.get(metric), epoch)
        lr_scheduler.step()

        if epoch % args.eval_freq == 0:
            test(epoch=epoch)

    # Only the final weights are saved, once the whole schedule has finished.
    checkpoint = {'state_dict': net.state_dict()}
    torch.save(checkpoint, os.path.join(args.save_dir, logname+"model.tar"))
    writer.close()
def test( epoch=None):
    """Evaluate the network on the test split and log the results.

    Switches the network to eval mode, resets the shared accumulator, runs
    every batch of ``test_loader`` through the network, then logs the
    accumulated ``'cent'`` / ``'acc'`` metrics and writes them to TensorBoard
    under ``validation/loss`` and ``validation/accuracy``.

    Args:
        epoch: Epoch number used in the log line and as the TensorBoard
            global step. Defaults to ``None``.

    Note:
        The network is left in eval mode and the shared accumulator holds the
        test metrics on return; ``train`` restores train mode and resets the
        accumulator at the start of each epoch.
    """
    net.eval()
    accm.reset()
    # Inference only: disable autograd so no computation graph is built for
    # the test batches (saves memory and time; metric values are unchanged).
    with torch.no_grad():
        for x, y in test_loader:
            x = x.cuda()
            y = y.cuda()
            outs = net(x)
            loss = criterion(outs, y)
            accm.update([loss.item(), accuracy(outs, y)])
    logger.info(accm.info(header='test', epoch=epoch))
    logger.info('\n' )
    writer.add_scalar("validation/loss",accm.get('cent') , epoch)
    writer.add_scalar("validation/accuracy", accm.get('acc'),epoch)



# Start training only when this file is executed as a script. The module-level
# setup above (argument parsing, logging, dataset and model creation) runs on
# import regardless of this guard.
if __name__ == '__main__':
    train()
