# Load a ResNet pretrained on CIFAR-100 and train it on CIFAR-10.


import argparse
import torch
import torch.nn as nn
import logging
import os
import sys
sys.path.append('../')
sys.path.append('../../')
from util.accumulator import *
from resnet18.model import  Net
from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import numpy as np
from tensorboardX import SummaryWriter
from torchsummary import summary
from data_sampler import *
# Command-line interface for the run. Options are registered in a fixed
# order so that the logged `str(args)` and the --help output stay stable.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--config',
    type=str,
    default='load resnet pretrained on cifar100 and train on cifar10',
)
parser.add_argument('--gpu', type=str, default='0')
parser.add_argument('--seed', type=int, default=42)
parser.add_argument('--num_epochs', type=int, default=200)
parser.add_argument('--eval_freq', type=int, default=1)
parser.add_argument(
    '--batch_size',
    type=int,
    default=128,
    help='batch size',
)
parser.add_argument(
    '--pretrain_path',
    type=str,
    default='./pretrain/resnetmodel.tar',
    help='cifar100 path',
)
# -1 selects the CIFAR-10 pipeline; any other value is handed to DATASET.
parser.add_argument('--data_id', type=int, default=-1)

parser.add_argument(
    '--save_dir',
    type=str,
    default='./full/',
    help='savedir name',
)
parser.add_argument('--id', type=str, default="resnet")
parser.add_argument('--lr', type=float, default=1e-1)

# Unrecognised command-line arguments are collected and discarded instead of
# aborting the run.
args, _ = parser.parse_known_args()

# TensorBoard event files for this run live under ./full/runs/<id>.
writer = SummaryWriter(os.path.join('./full/runs/', args.id))


# Text logging: every record goes to stdout and to <save_dir>/<id>.log.
logname = args.id
os.makedirs(args.save_dir, exist_ok=True)
log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
)
logger = logging.getLogger()
# The file handler reuses the message format but not the custom datefmt.
fh = logging.FileHandler(os.path.join(args.save_dir, logname + ".log"))
fh.setFormatter(logging.Formatter(log_format))
logger.addHandler(fh)
# Record the full configuration at the top of the log.
logger.info(str(args) + '\n')


# Limit the CUDA devices visible to this process to the ones given by --gpu.
os.environ.update({'CUDA_VISIBLE_DEVICES': args.gpu})

# Let cuDNN auto-tune its kernels.
torch.backends.cudnn.benchmark = True

# Seed both the CPU and the CUDA random number generators when a seed is set.
if args.seed is not None:
    for _seed_fn in (torch.manual_seed, torch.cuda.manual_seed):
        _seed_fn(args.seed)

# Build `train_loader` / `test_loader` for the dataset chosen with --data_id.
if args.data_id != -1:
    # Any id other than -1 is delegated to the project's DATASET helper,
    # restricted to class indices 0..9.
    D = DATASET(idx=args.data_id,
                minibatch_size=args.batch_size,
                cls_idx=list(range(10)))
    _, train_loader, test_loader = D.get_loader()
    # NOTE(review): every non-zero id is logged as "svhn" - confirm that
    # DATASET only supports those two datasets.
    logger.info("mnist" if args.data_id == 0 else "svhn")
else:
    # Default: CIFAR-10 with crop/flip augmentation on the training split.
    _cifar_root = '../dataset/cifar10'
    _norm_mean = (0.4914, 0.4822, 0.4465)
    _norm_std = (0.2023, 0.1994, 0.2010)

    _train_tf = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ])
    _test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ])

    # Only the training split requests a download; the test split is opened
    # afterwards from the same root.
    train_dataset = datasets.CIFAR10(root=_cifar_root, train=True,
                                     transform=_train_tf, download=True)
    test_dataset = datasets.CIFAR10(root=_cifar_root, train=False,
                                    transform=_test_tf)

    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
    )
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset,
        batch_size=args.batch_size,
        shuffle=False,
    )
# Build a 10-class network and initialise it from the pretrained checkpoint.
net = Net(num_classes=10)


# Load the checkpoint onto the CPU so it does not depend on the GPU index it
# was saved from still being visible; the model is moved to the GPU below.
pretrained_state = torch.load(args.pretrain_path, map_location='cpu')['state_dict']
new_dict = {}

# Checkpoints saved from an nn.DataParallel model prefix every key with
# 'module.'; strip only that leading prefix so keys that merely contain the
# substring (e.g. 'submodule.') are left intact.
_dp_prefix = 'module.'
for _key, _tensor in pretrained_state.items():
    # Parameters whose name contains 'linear' are not transferred, so `net`
    # keeps its own freshly initialised values for them (presumably the
    # classification head, whose size differs between the two tasks).
    if 'linear' in _key:
        continue
    if _key.startswith(_dp_prefix):
        _key = _key[len(_dp_prefix):]
    new_dict[_key] = _tensor

# Overlay the pretrained tensors on the model's own state and load the result.
merged_state = net.state_dict()
merged_state.update(new_dict)
net.load_state_dict(merged_state)


# Wrap for multi-GPU execution and move the parameters to the GPU.
net = nn.DataParallel(net)
net.cuda()


# SGD with Nesterov momentum over all parameters of the wrapped model.
optimizer = torch.optim.SGD(
    net.parameters(),
    lr=args.lr,
    momentum=0.9,
    weight_decay=1e-4,
    nesterov=True,
)
# The learning rate is multiplied by 0.1 at 40% and at 60% of the epochs.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[int(r * args.num_epochs) for r in [.4, .6]],
    gamma=0.1,
)


criterion = nn.CrossEntropyLoss().cuda()

# `accm` tracks loss ('cent') and accuracy ('acc'); it is reset and reused by
# both the training and the evaluation loops. `TIME` accumulates the measured
# step durations over the whole run.
accm = Accumulator('cent', 'acc')
TIME = Accumulator('time')
TIME.reset()

# Counters kept at module level; nothing in the visible code reads them.
num = 0
pnum = 0



def train():
    """Train ``net`` for ``args.num_epochs`` epochs and save the final weights.

    Each epoch:
      * logs the current learning rate of every optimizer parameter group,
      * runs one pass over ``train_loader``, timing forward + backward +
        optimizer step of every batch with CUDA events,
      * writes loss / accuracy scalars to TensorBoard and the text log,
      * advances the learning-rate scheduler,
      * calls ``test`` every ``args.eval_freq`` epochs.

    After the last epoch the model's ``state_dict`` is saved to
    ``<save_dir>/<id>model.tar``. The TensorBoard writer is closed on the way
    out even when training is interrupted or fails, so events that were
    already recorded are flushed.
    """
    try:
        for epoch in range(1, args.num_epochs + 1):
            accm.reset()
            lr_text = ''.join(
                ' {:.3e}'.format(pg['lr']) for pg in optimizer.param_groups
            )
            logger.info('epoch {} starts with lr'.format(epoch) + lr_text)
            # test() switches the model to eval mode, so re-enable training
            # mode at the start of every epoch.
            net.train()

            for x, y in train_loader:
                x = x.cuda()
                y = y.cuda()
                optimizer.zero_grad()

                start = torch.cuda.Event(enable_timing=True)
                end = torch.cuda.Event(enable_timing=True)

                # Timed region: forward, loss, backward and parameter update.
                start.record()
                outs = net(x)
                loss = criterion(outs, y)
                loss.backward()
                optimizer.step()
                end.record()

                # The events are only valid once the queued GPU work is done.
                torch.cuda.synchronize()
                step_ms = start.elapsed_time(end)  # milliseconds

                TIME.update([step_ms / 1000])
                # Compute the batch accuracy once and reuse it for both the
                # TensorBoard curve and the epoch accumulator.
                batch_acc = accuracy(outs, y)
                # Accuracy plotted against the accumulated training time
                # (used as the x-axis instead of an iteration counter).
                writer.add_scalar("training/iter_time", batch_acc,
                                  TIME.get_sum('time'))
                accm.update([loss.item(), batch_acc])

            logger.info(accm.info(header='train', epoch=epoch))
            writer.add_scalar("training/loss", accm.get('cent'), epoch)
            writer.add_scalar("training/accuracy", accm.get('acc'), epoch)
            logger.info("Time spent: {}".format(TIME.get_sum('time')))
            writer.add_scalar("training/epoch_time", accm.get('acc'),
                              TIME.get_sum('time'))

            lr_scheduler.step()

            if epoch % args.eval_freq == 0:
                test(epoch=epoch)

        torch.save({'state_dict': net.state_dict()},
                   os.path.join(args.save_dir, logname + "model.tar"))
    finally:
        writer.close()

def test(epoch=None):
    """Evaluate ``net`` on ``test_loader`` and log loss, accuracy and timing.

    The model is switched to eval mode and the shared ``accm`` accumulator is
    reset, so training statistics must be read before calling this function.
    The forward pass and loss of every batch are timed with CUDA events and
    the total is logged in seconds.

    Args:
        epoch: Epoch number used as the step for the validation loss and
            accuracy scalars and in the log header.
    """
    net.eval()
    accm.reset()
    total_ms = 0
    # Evaluation needs no gradients; disabling autograd avoids building a
    # graph for every batch and lowers memory use.
    with torch.no_grad():
        for x, y in test_loader:
            x = x.cuda()
            y = y.cuda()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            # Timed region: forward pass and loss only.
            start.record()
            outs = net(x)
            loss = criterion(outs, y)
            end.record()

            # The events are only valid once the queued GPU work is done.
            torch.cuda.synchronize()
            total_ms += start.elapsed_time(end)  # milliseconds

            accm.update([loss.item(), accuracy(outs, y)])

    logger.info(accm.info(header='test', epoch=epoch))
    logger.info('\n')
    logger.info('infer {:.4f}'.format(total_ms / 1000))

    writer.add_scalar("validation/loss", accm.get('cent'), epoch)
    writer.add_scalar("validation/accuracy", accm.get('acc'), epoch)
    # Validation accuracy plotted against the accumulated training time.
    writer.add_scalar("validation/time", accm.get('acc'), TIME.get_sum('time'))



# Script entry point: start the training loop when this file is executed
# directly rather than imported.
if __name__ == '__main__':
    train()
