# pruned architecture: --p 32 32 32 32 32 32 32 32 32  32 32 32 32 32 32 32

#structure accuracy check (random/metapruning/stamp)
import argparse
import torch
import torch.nn as nn
import logging
import os
import sys
import math
import random
sys.path.append('../')
sys.path.append('../../')
from util.accumulator import *
from vgg import  *
from prunedvgg import  *
from module.bbdrop import BBDropout

from set.dbbdrop import DBBDropout
from util.accumulator import *

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import numpy as np
from tensorboardX import SummaryWriter
from torchsummary import summary
import pickle
from recal import *
from data_sampler import *
from util.real_flops import real

# Command-line options for the structure accuracy check.
# Options not referenced in the visible part of this script (config,
# pretrain_path, gamma, kl_scale, a, dkl_scale, set_lr, rho, bb_lr, dbb_lr)
# are kept so existing launch commands keep parsing.
parser = argparse.ArgumentParser()
parser.add_argument('--config', default='structure accuracy check (random/metapruning/stamp)', type=str)
parser.add_argument('--gpu', default='0', type=str)
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--num_epochs', default=200, type=int)
parser.add_argument('--eval_freq', default=1, type=int)
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
# Help text previously read 'batch size' (copy-paste slip); the value is
# forwarded to DATASET as `setsize`.
parser.add_argument('--set_size', type=int, default=64, help='set size (forwarded to DATASET as setsize)')

# Help text previously read 'savedir name' (copy-paste slip from --save_dir).
parser.add_argument('--pretrain_path', type=str, default='./metatrain/_2000.tar', help='pretrained checkpoint path')
parser.add_argument('--save_dir', type=str, default='./structure', help='savedir name')
parser.add_argument('--id', default="_", type=str)
parser.add_argument('--lr', default=1e-1, type=float)
parser.add_argument('--gamma', default=1./60000, type=float)
parser.add_argument('--kl_scale', default=1.0, type=float)
parser.add_argument('--a',default=2, type=int, help='auc')
parser.add_argument('--dkl_scale', default=1.0, type=float)
parser.add_argument('--model_lr', default=1e-3, type=float)
parser.add_argument('--set_lr', default=1e-3, type=float)
parser.add_argument('--shot', default=1, type=int)
parser.add_argument('--rho', default=5, type=float)
parser.add_argument('--data_id',default=-1, type=int, help='data')
parser.add_argument('--adam',default=False, action='store_true')
parser.add_argument('--p', nargs='*',type=int,
                    help='pruned architecture: per-layer channel counts, '
                         'or channel-scale ids when --meta is set')
parser.add_argument('--meta',default=False, action='store_true',
                    help='treat --p as channel-scale ids and convert them with adapt_channel')

parser.add_argument('--bb_lr', default=0.01, type=float)
parser.add_argument('--dbb_lr', default=0.01, type=float)
# parse_known_args: unrecognised command-line arguments are ignored rather
# than rejected.
args, _ = parser.parse_known_args()

# TensorBoard event writer; events go to ./structure/runs/<id>.
writer = SummaryWriter(os.path.join('./structure/runs', args.id))

# Index lists and a derived mini-batch count. None of these names is
# referenced again in the visible part of this script.
l = list(range(100))
cl = list(range(5000))  # TODO:  5000
shuffle = list(range(args.set_size * 10))
batch_num = (args.set_size * 10) // args.batch_size


# Candidate width multipliers: 0.10, 0.13, ..., 1.00 (31 values, step 0.03).
channel_scale = [(10 + step * 3) / 100 for step in range(31)]


def adapt_channel(ids):
    """Convert channel-scale ids into per-layer channel counts.

    Only the first 15 entries of ``ids`` are read; the 16th layer always uses
    index -1, i.e. the last (1.00) entry of ``channel_scale``. Each count is
    the base width of the layer times the selected scale, truncated to int.
    The chosen ids and the resulting widths are printed to stdout.

    Args:
        ids: indexable sequence with at least 15 integer indices into
            ``channel_scale``.

    Returns:
        List of 16 integer channel counts.
    """
    scale_ids = [ids[pos] for pos in range(15)]
    scale_ids.append(-1)
    print(scale_ids)

    base_widths = [64, 64, 128, 128,
                   256, 256, 256, 256,
                   512, 512, 512, 512, 512, 512, 512, 512]
    overall_channel = [
        int(width * channel_scale[scale_id])
        for width, scale_id in zip(base_widths, scale_ids)
    ]
    print(overall_channel)

    return overall_channel
# Logging: INFO and above go to stdout and to <save_dir>/<id>.log.
logname = args.id
os.makedirs(args.save_dir, exist_ok=True)

log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
)
logger = logging.getLogger()

# The file handler reuses the message format but not the custom datefmt.
fh = logging.FileHandler(os.path.join(args.save_dir, logname + ".log"))
fh.setFormatter(logging.Formatter(log_format))
logger.addHandler(fh)

# Record the full run configuration at the top of the log.
logger.info(str(args) + '\n')


# Select the visible GPU(s); this is set before the first `.cuda()` call in
# this script.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# Let cuDNN auto-tune convolution algorithms (can make runs non-bit-identical).
torch.backends.cudnn.benchmark = True
if args.seed is not None:
    # Seed every RNG the script has imported, not only torch: previously
    # Python's `random` and NumPy stayed unseeded, so anything drawing from
    # them was not reproducible under --seed.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)


def test(epoch=None):
    """Evaluate ``net`` on ``test_loader`` and log loss, accuracy and forward time.

    Puts the network in eval mode, resets ``accm``, runs every test batch
    through the network while timing only the forward pass with CUDA events,
    then logs the accumulated metrics and writes them to TensorBoard.

    Args:
        epoch: value used in the log header and as the TensorBoard step for
            the loss/accuracy scalars (may be None).
    """
    net.eval()
    accm.reset()
    infer_ms = 0

    # Evaluation never calls backward(); disabling autograd avoids building
    # graphs, which saves memory and keeps the timed forward pass honest.
    with torch.no_grad():
        for x, y in test_loader:
            x = x.cuda()
            y = y.cuda()

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            # Drain pending work so the timing covers only the forward pass.
            torch.cuda.synchronize()
            start.record()
            outs = net(x)
            end.record()
            # elapsed_time needs both events to have completed.
            torch.cuda.synchronize()
            infer_ms += start.elapsed_time(end)

            loss = criterion(outs, y)
            accm.update([loss.item(), accuracy(outs, y)])

    logger.info(accm.info(header='test', epoch=epoch))

    # elapsed_time reports milliseconds; log the total in seconds.
    logger.info('infer {:.4f}'.format(infer_ms/1000))

    writer.add_scalar("validation/loss", accm.get('cent'), epoch)
    writer.add_scalar("validation/accuracy", accm.get('acc'), epoch)
    # Accuracy again, but with the TIME accumulator's 'time' sum as the step
    # (presumably cumulative training time — confirm against Accumulator).
    writer.add_scalar("validation/time", accm.get('acc'), TIME.get_sum('time'))



# CIFAR-10 data. Both splits share the same root and per-channel normalisation;
# only the training split is augmented (random crop + horizontal flip) and
# only the training split requests a download.
_cifar_root = '../dataset/cifar10'
_normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))

_train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    _normalize,
])
_test_transform = transforms.Compose([
    transforms.ToTensor(),
    _normalize,
])

train_dataset = datasets.CIFAR10(root=_cifar_root, train=True,
                                 transform=_train_transform, download=True)
test_dataset = datasets.CIFAR10(root=_cifar_root, train=False,
                                transform=_test_transform)

train_loader = torch.utils.data.DataLoader(
    dataset=train_dataset, batch_size=args.batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(
    dataset=test_dataset, batch_size=args.batch_size, shuffle=False)


# Loss plus the metric accumulators shared by train() and test().
criterion = nn.CrossEntropyLoss().cuda()
accm = Accumulator('cent', 'acc')
TIME = Accumulator('time')
TIME.reset()

# A non-negative --data_id swaps the CIFAR-10 loaders for the ones provided by
# DATASET (logged as "mnist" for id 0 and "svhn" for any other id).
if args.data_id >= 0:
    D = DATASET(
        idx=args.data_id,
        minibatch_size=args.batch_size,
        setsize=args.set_size,
        cls_idx=list(range(10)),
    )
    # S is not used again in the visible part of this script.
    S = D.sample(train=False)
    # get_loader() yields three values; the first one is discarded here.
    _, train_loader, test_loader = D.get_loader()
    logger.info("mnist" if args.data_id == 0 else "svhn")
#PRUNED=adapt_channel([15, 28, 19, 19,  9, 15, 15,  2, 15,  8, 11, 26,  4, 30,  0, 20] )#[15. 28. 19. 19.  9. 15. 15.  2. 15.  8. 11. 26.  4. 30.  0. 20.] Top-1 err = 85.5        / lr 0.01 -> acc : 24 at training
#PRUNED=adapt_channel([19, 18, 10, 18, 14,  5, 11,  9, 27,  6, 25, 22, 11, 18,  9, 28] )#SVHN Top-1 err = 86.27000427246094      / lr 0.01 -> acc : 24 at training
#PRUNED=[27, 23, 62, 66, 156, 163, 163, 171, 358, 189, 327, 343, 235, 250, 296, 235] #random (<FLOP CONSTRAINT)

#PRUNED=[64, 64, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512,  512, 512, 512, 512]
#PRUNED=[46, 64, 107, 102, 132, 91, 70, 52, 51, 39, 45, 33, 37, 43, 49, 134] #our structure cifar10

#PRUNED=adapt_channel([19, 18, 10, 18, 14,  5, 11,  9, 27,  6, 25, 22, 11, 18,  9, 28] )#SVHN Top-1 err = 86.27000427246094      / lr 0.01 -> acc : 24 at training


#PRUNED=[33, 31, 14, 16, 16, 39, 24, 66, 125, 117, 125, 48, 220, 235, 204, 512] #svhn flop x16  randomxxxxxxx
#PRUNED=adapt_channel([22, 13, 14,  3, 19, 24, 10, 13,  8,  1,  8,  1,  0,  5,  3, 13] )#  svhnx16 :85  / lr 0.001 -> acc : 48 at training

#PRUNED=adapt_channel([23,        14,         11,          6,         12,         25, 10,         14,         25,         20,         10,         26, 0,          5,          6,         12     ] )#cifar48
  # [ 2.  1. 11.  3.  3.  4. 12.  2. 12.  8.  1.  2.  1.  5.  5.  4.] Top-1 err = 84.48999786376953
#No.2 [ 2.  1. 11.  3.  3.  4. 12.  2. 12.  8.  1.  2.  1.  5.  5.  1.] Top-1 err = 84.5
#No.3 [ 2.  1. 11.  3.  3.  4. 12.  2. 11.  8.  1.  2.  1.  5.  5.  3.]

#PRUNED=adapt_channel([ 2 , 1 ,11  ,3 , 3 , 4 ,12 , 2 ,12 , 8  ,1 , 2 , 1  ,5  ,5 , 4])#svhn flop x16  top1
#PRUNED=adapt_channel([ 2 , 1 ,11  ,3 , 3 , 4 ,11 , 2 ,12 , 8  ,1 , 2 , 1  ,5  ,5 , 4])#svhn flop x16  top2
#PRUNED=adapt_channel([ 2 , 1 ,11  ,3 , 3 , 4 ,12 , 2 ,7 , 8  ,1 , 2 , 1  ,5  ,5 , 1])#svhn flop x16  top3

# Architecture: --p is used as given, or first converted from channel-scale
# ids to channel counts when --meta is set.
PRUNED = adapt_channel(args.p) if args.meta else args.p
logger.info(PRUNED)
real(PRUNED, logger)

net = VGG_CIFAR_PRUNED(num_classes=10, pruned=PRUNED, logger=logger).cuda()

# Optimizer: Nesterov SGD at --lr by default, Adam at --model_lr with --adam.
if args.adam:
    optimizer = torch.optim.Adam(net.parameters(), lr=args.model_lr, weight_decay=1e-4)
else:
    optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9,
                                weight_decay=1e-4, nesterov=True)

# Learning rate is multiplied by 0.1 at 40% and 60% of --num_epochs.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[int(r*args.num_epochs) for r in (.4, .6)],
    gamma=0.1,
)
criterion = nn.CrossEntropyLoss().cuda()



def train():
    """Train ``net`` for epochs ``args.shot`` .. ``args.num_epochs - 1`` and save it.

    Each epoch logs the current learning rate(s), runs one pass over
    ``train_loader`` while timing every optimisation step with CUDA events,
    logs and records the epoch loss/accuracy, steps the LR scheduler and runs
    ``test`` whenever ``epoch % args.eval_freq == 0``. After the last epoch
    the model's state dict is written to ``<save_dir>/<id>model.tar`` and the
    TensorBoard writer is closed.
    """
    for epoch in range(args.shot, args.num_epochs):
        accm.reset()
        lr_text = ''.join(' {:.3e}'.format(pg['lr']) for pg in optimizer.param_groups)
        logger.info('\n epoch {} starts with lr'.format(epoch) + lr_text)
        net.train()
        epoch_ms = 0

        for x, y in train_loader:
            x = x.cuda()
            y = y.cuda()
            optimizer.zero_grad()

            # Time forward + backward + optimizer step. Unlike test(), there
            # is no synchronize before start.record() here.
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            outs = net(x)
            loss = criterion(outs, y)
            loss.backward()
            optimizer.step()
            end.record()

            # elapsed_time needs both events to have completed.
            torch.cuda.synchronize()
            step_ms = start.elapsed_time(end)

            epoch_ms += step_ms
            # elapsed_time reports milliseconds; TIME is fed seconds.
            TIME.update([step_ms/1000])

            # Compute the batch accuracy once and reuse it; it was previously
            # evaluated twice per iteration on the same tensors.
            batch_acc = accuracy(outs, y)
            # Batch accuracy plotted against the TIME accumulator's 'time' sum.
            writer.add_scalar("training/iter_time", batch_acc, TIME.get_sum('time'))
            accm.update([loss.item(), batch_acc])

        logger.info(accm.info(header='train', epoch=epoch))
        logger.info(epoch_ms/1000)
        writer.add_scalar("training/loss", accm.get('cent'), epoch)
        writer.add_scalar("training/accuracy", accm.get('acc'), epoch)
        logger.info("Time spent: {}".format(TIME.get_sum('time')))
        writer.add_scalar("training/epoch_time", accm.get('acc'), TIME.get_sum('time'))

        lr_scheduler.step()

        if epoch % args.eval_freq == 0:
            test(epoch=epoch)

    torch.save({'state_dict': net.state_dict()},
               os.path.join(args.save_dir, logname+"model.tar"))
    writer.close()



# Start training only when run as a script. Note that the module-level setup
# in this file (argument parsing, data loading, model construction) runs on
# import as well.
if __name__ == '__main__':
    train()
