#cifar100->cifar10 vgg19 bbdropout
import argparse
import torch
import torch.nn as nn
import torchvision.datasets as dsets
import torchvision.transforms as transforms
from torch.utils.data import Dataset

import logging
import os
import sys
sys.path.append('../')
sys.path.append('../../')

from util.accumulator import *
from module.bbdrop import BBDropout
from module.dbbdrop import DBBDropout

from vgg import VGG_CIFAR_gated,VGG_CIFAR
import time
from tensorboardX import SummaryWriter
from data_sampler import *
from util.real_flops import real

# Command-line options as (flag, add_argument keyword arguments) pairs.
# They are registered in exactly this order, so the --help listing is stable.
_CLI_OPTIONS = (
    # run control
    ('--config', {'default': 'cifar100->cifar10 vgg19 bbdropout', 'type': str}),
    ('--gpu', {'default': 0, 'type': int}),
    ('--num_epochs', {'default': 200, 'type': int}),
    ('--seed', {'default': 42, 'type': int}),
    # checkpoint / evaluation cadence and data sizes
    ('--save_freq', {'default': 200, 'type': int}),
    ('--eval_freq', {'default': 1, 'type': int}),
    ('--batch_size', {'type': int, 'default': 128, 'help': 'batch size'}),
    ('--task_size', {'type': int, 'default': 50000, 'help': 'task size'}),
    # regularisation and optimiser choice
    ('--gamma', {'default': 1./60000, 'type': float}),
    ('--kl_scale', {'default': 1.0, 'type': float}),
    ('--a', {'default': 2, 'type': int, 'help': 'auc'}),
    ('--sgd', {'default': False, 'action': 'store_true'}),
    # paths, run identity and learning rates
    ('--pretrain_path', {'default': './pretrain/_model.tar', 'type': str}),
    ('--save_dir', {'type': str, 'default': './bb/vgg', 'help': 'savedir name'}),
    ('--id', {'default': '_', 'type': str}),
    ('--data_id', {'default': -1, 'type': int}),
    ('--model_lr', {'default': 1e-3, 'type': float}),
    ('--bb_lr', {'default': 1e-2, 'type': float}),
)

parser = argparse.ArgumentParser()
for _flag, _flag_kwargs in _CLI_OPTIONS:
    parser.add_argument(_flag, **_flag_kwargs)

# parse_known_args tolerates flags this script does not declare; the list of
# leftovers is discarded.
args, _ = parser.parse_known_args()

# TensorBoard event files are written under a fixed ./bb/vgg/runs/<id>
# directory (this path does not follow --save_dir).
writer = SummaryWriter(os.path.join('./bb/vgg/runs',args.id))
LOGNAME = args.id

# exist_ok=True replaces the isdir()-then-makedirs() pair, which could fail if
# another process created the directory between the check and the call.
os.makedirs(args.save_dir, exist_ok=True)

# Log to stdout (via basicConfig) and to <save_dir>/<id>.log on the root logger.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save_dir, LOGNAME+".log"))
fh.setFormatter(logging.Formatter(log_format))
logger = logging.getLogger()
logger.addHandler(fh)
# Record the full configuration at the top of the log.
logger.info(str(args) + '\n')


# Expose only the GPU selected with --gpu to this process.
# NOTE(review): presumably this has to run before the first CUDA call for the
# setting to take effect -- keep it ahead of any .cuda() use.
os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
# Turn on the cuDNN benchmark flag.
# NOTE(review): this may make runs not bit-for-bit repeatable even with a
# fixed seed -- confirm if exact reproducibility matters.
torch.backends.cudnn.benchmark = True
# Seed the CPU and CUDA random number generators when a seed is given.
if args.seed is not None:
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)


# Build the gated VGG for the 10-class target task and warm-start it from the
# pretrained checkpoint.
net = VGG_CIFAR_gated(num_classes=10)

# Load onto the CPU first: the checkpoint may have been saved from a GPU index
# that is not visible in this process, and the model is moved to the GPU below.
pdict = torch.load(args.pretrain_path, map_location='cpu')['state_dict']
model_dict = net.state_dict()

# Pair tensors by normalised name ("module." stripped from checkpoint keys,
# "base." stripped from model keys) instead of by position, so that a
# difference in key order or an extra key on either side cannot silently
# prevent the remaining layers from being loaded.
pretrained_by_name = {pk.replace('module.', ''): pk for pk in pdict}
new_dict = {}
for k in model_dict:
    pk = pretrained_by_name.get(k.replace("base.", ""))
    # The classifier is skipped and keeps its fresh initialisation.
    if pk is not None and 'classifier' not in pk:
        new_dict[k] = pdict[pk]

model_dict.update(new_dict)
net.load_state_dict(model_dict)
# Make a failed or partial warm start visible instead of training from
# scratch without notice.
logger.info('loaded {}/{} tensors from {}'.format(
    len(new_dict), len(model_dict), args.pretrain_path))


# Attach the dropout gates, then wrap for multi-GPU use and move to the GPU.
net.build_gate(BBDropout, {'a_uc_init': args.a, 'kl_scale': args.kl_scale})
net=nn.DataParallel(net)
net.cuda()



# Partition the parameters: anything whose name contains 'gate' is trained
# with --bb_lr, everything else with --model_lr.
_named_params = list(net.named_parameters())
gate_params = [p for n, p in _named_params if 'gate' in n]
base_params = [p for n, p in _named_params if 'gate' not in n]

if args.sgd:
    # SGD variant: weight decay on both groups, Nesterov momentum.
    optimizer = torch.optim.SGD(
        [
            {'params': gate_params, 'lr': args.bb_lr, 'weight_decay': 1e-4},
            {'params': base_params, 'lr': args.model_lr, 'weight_decay': 1e-4},
        ],
        momentum=0.9,
        nesterov=True,
    )
else:
    # Default Adam variant: weight decay only on the non-gate parameters.
    optimizer = torch.optim.Adam(
        [
            {'params': gate_params, 'lr': args.bb_lr},
            {'params': base_params, 'lr': args.model_lr, 'weight_decay': 1e-4},
        ]
    )

# Multiply every learning rate by 0.1 at 40% and at 60% of the epoch budget.
scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[int(.4 * args.num_epochs), int(.6 * args.num_epochs)],
    gamma=0.1,
)

criterion = nn.CrossEntropyLoss().cuda()

# Running statistics: loss/accuracy per phase, and accumulated timing.
accm = Accumulator('cent', 'acc')
TIME = Accumulator('time')
TIME.reset()

# Default data source: CIFAR-10 with per-channel normalisation; the training
# split additionally gets random crop and horizontal flip augmentation.
_CIFAR10_ROOT = '../dataset/cifar10'
_CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
_CIFAR10_STD = (0.2023, 0.1994, 0.2010)

_train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])
_test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(_CIFAR10_MEAN, _CIFAR10_STD),
])

train_dataset = dsets.CIFAR10(root=_CIFAR10_ROOT, train=True,
                              transform=_train_transform, download=True)
test_dataset = dsets.CIFAR10(root=_CIFAR10_ROOT, train=False,
                             transform=_test_transform)

# Optionally train on a random subset of --task_size of the 50000 images.
if args.task_size != 50000:
    train_dataset, _ = torch.utils.data.random_split(
        train_dataset, [args.task_size, 50000 - args.task_size])
    print(len(train_dataset))

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=args.batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=args.batch_size, shuffle=False)

"""
if args.data_id>=0:
    D=DATASET(idx=args.data_id,minibatch_size=args.batch_size,cls_idx=[0,1,2,3,4,5,6,7,8,9])
    train_loader,test_loader=D.get_loader()
    if args.data_id==0:
        logger.info("mnist")
    else:
        logger.info("svhn")
"""
# A non-negative --data_id replaces the CIFAR-10 loaders above with loaders
# from the project's DATASET sampler.
if args.data_id >= 0:
    # data_id 2 uses 100 class indices; every other id uses the first ten.
    _hundred_way = args.data_id == 2
    _class_ids = range(100) if _hundred_way else list(range(10))

    D = DATASET(idx=args.data_id, minibatch_size=args.batch_size,
                cls_idx=_class_ids, task_size=args.task_size)
    S = D.sample(train=False)
    train_loader0, train_loader, test_loader = D.get_loader()

    # Only the ten-class ids announce their dataset name in the log.
    if not _hundred_way:
        logger.info("mnist" if args.data_id == 0 else "svhn")

def train():
    """Run the full training schedule and return the final pruned-size report.

    For each of ``args.num_epochs`` epochs this trains ``net`` on
    ``train_loader`` with the loss
    ``cross-entropy + args.gamma * net.module.get_reg()``, steps ``scheduler``
    once per epoch, and reports loss, accuracy and accumulated GPU time to
    ``logger`` and the TensorBoard ``writer``.

    ``test`` runs every ``args.eval_freq`` epochs (``0`` disables the periodic
    runs) and once more after the last epoch if that epoch was not already
    evaluated, so the returned value always describes the weights that are
    saved to ``<save_dir>/<id>model.tar``.

    Returns:
        The value returned by the final ``test`` call.
    """
    P = None
    last_eval_epoch = None

    for epoch in range(1, args.num_epochs+1):
        accm.reset()
        line = 'epoch {} starts with lr'.format(epoch)
        for pg in optimizer.param_groups:
            line += ' {:.3e}'.format(pg['lr'])
        logger.info(line)
        net.train()
        tt=0  # GPU time spent in this epoch, in milliseconds
        for x, y in train_loader:
            x = x.cuda()
            y = y.cuda()
            optimizer.zero_grad()

            # Time forward + backward + optimizer step with CUDA events;
            # synchronising first keeps earlier queued work out of the window.
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()

            start.record()
            outs = net(x)
            cent = criterion(outs, y)
            reg = net.module.get_reg().cuda()
            loss = cent + args.gamma*reg
            loss.backward()
            optimizer.step()

            end.record()
            torch.cuda.synchronize()
            t=start.elapsed_time(end)  # milliseconds
            tt+=t
            TIME.update([t/1000])  # accumulate in seconds

            # Compute the batch accuracy once and reuse it for both sinks.
            batch_acc = accuracy(outs, y)
            # Accuracy plotted against cumulative training time.
            writer.add_scalar("training/iter_time", batch_acc, TIME.get_sum('time'))

            accm.update([cent.item(), batch_acc])
        line = accm.info(header='train', epoch=epoch)
        logger.info(line)
        scheduler.step()
        writer.add_scalar("training/loss", accm.get('cent') , epoch)
        writer.add_scalar("training/accuracy", accm.get('acc'),epoch)
        logger.info("{:.4f}".format(tt/1000))

        logger.info("Time spent: {}".format(TIME.get_sum('time')))
        writer.add_scalar("training/epoch_time", accm.get('acc') , TIME.get_sum('time'))

        # eval_freq == 0 means "no periodic evaluation" rather than a
        # ZeroDivisionError from the modulo.
        if args.eval_freq and epoch % args.eval_freq == 0:
            P=test(epoch=epoch)
            last_eval_epoch = epoch

    # Guarantee that P exists and matches the weights about to be saved, even
    # when the last epoch was not on the evaluation schedule.
    if last_eval_epoch != args.num_epochs:
        P=test(epoch=args.num_epochs)

    torch.save({'state_dict':net.state_dict()},os.path.join(args.save_dir, LOGNAME+"model.tar"))
    return P
def test(epoch=None):
    """Evaluate ``net`` on ``test_loader`` and report accuracy and pruning statistics.

    Puts ``net`` in eval mode and resets ``accm`` before accumulating the
    test loss and accuracy. Results go to ``logger`` and to the
    ``validation/*`` TensorBoard scalars.

    Args:
        epoch: Step index used in the log header and as the TensorBoard step
            for the per-epoch scalars. Defaults to ``None``.

    Returns:
        The value of ``net.module.get_pruned_size()``.
    """
    net.eval()
    accm.reset()
    tt=0  # GPU time spent on forward passes, in milliseconds

    # Inference only: skip autograd bookkeeping so no graph is kept alive.
    with torch.no_grad():
        for x, y in test_loader:
            x = x.cuda()
            y = y.cuda()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()

            start.record()
            outs = net(x)
            cent = criterion(outs, y)

            end.record()
            torch.cuda.synchronize()
            tt+=start.elapsed_time(end)
            accm.update([cent.item(), accuracy(outs, y)])

    model = net.module
    logger.info(accm.info(header='test', epoch=epoch))
    logger.info('reg {:.4f}'.format(model.get_reg().item()))

    # Query each statistic once so the log and TensorBoard report the same
    # numbers.
    pruned_size = model.get_pruned_size()
    pruned_weight_sum = model.get_pruned_weight_sum()
    # NOTE(review): 20040522 is presumably the weight count of the unpruned
    # network -- confirm before reusing this script with another architecture.
    pruned_weight_pct = pruned_weight_sum*100/20040522
    speedup = model.get_speedup()
    memory_saving = model.get_memory_saving()

    logger.info('pruned size {}'.format(str(pruned_size)))
    logger.info('nonactive weight {} ({:.4f}%)'.format(str(pruned_weight_sum), pruned_weight_pct))
    logger.info("infer {:.4f}".format(tt/1000))
    logger.info('speedup in flops {:.4f}'.format(speedup))
    logger.info('memory saving {:.4f}\n'.format(memory_saving))

    writer.add_scalar("validation/loss",accm.get('cent') , epoch)
    writer.add_scalar("validation/accuracy", accm.get('acc'),epoch)
    writer.add_scalar("validation/nonactive_weights", pruned_weight_pct,epoch)
    writer.add_scalar("validation/flop_speedup", speedup,epoch)
    writer.add_scalar("validation/memory_saved", memory_saving,epoch)

    # Test accuracy plotted against cumulative training time.
    writer.add_scalar("validation/time", accm.get('acc'),TIME.get_sum('time'))
    return pruned_size


if __name__ == '__main__':
    try:
        P=train()
    finally:
        # Close the TensorBoard writer even if training fails, so the events
        # written so far are not left unflushed.
        writer.close()
    # Hand the final pruned-size report to the project's `real` helper.
    real(P,logger)
