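# Example pruned architecture: --p 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32 32
# (16 per-layer channel counts passed to RESNET_PRUNED)

# Structure accuracy check (random / MetaPruning / STAMP architectures).
# For SVHN pass --data_id 1 (0 selects MNIST; the default -1 trains on CIFAR10).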
import argparse
import torch
import torch.nn as nn
import logging
import os
import sys
import math
import random
sys.path.append('../dad')
sys.path.append('../../')
from util.accumulator import *
from module.bbdrop import BBDropout
from set.dbbdrop import DBBDropout
from util.accumulator import *

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import numpy as np
from tensorboardX import SummaryWriter
from torchsummary import summary
import pickle
from data_sampler import *
from util.real_flops import *

# Command-line options.
# NOTE: --config, --gamma, --kl_scale, --a, --dkl_scale, --set_lr, --rho,
# --bb_lr and --dbb_lr are not read anywhere in this script; they are kept so
# existing command lines continue to parse.
parser = argparse.ArgumentParser()
parser.add_argument('--config', default='structure accuracy check (random/metapruning/stamp)', type=str)
parser.add_argument('--gpu', default='0', type=str)  # exported as CUDA_VISIBLE_DEVICES
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--num_epochs', default=200, type=int)
parser.add_argument('--eval_freq', default=1, type=int)  # run test() every N epochs
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
# Help text previously read 'batch size' (copy-paste); this value is the
# setsize handed to the DATASET sampler.
parser.add_argument('--set_size', type=int, default=64, help='set size for the DATASET sampler')

# Directory for the .log file and the final checkpoint.
parser.add_argument('--save_dir', type=str, default='./structure', help='savedir name')
# Run name: TensorBoard sub-directory and prefix of the log/checkpoint files.
parser.add_argument('--id', default="_", type=str)
parser.add_argument('--lr', default=1e-1, type=float)  # SGD learning rate
parser.add_argument('--gamma', default=1./60000, type=float)
parser.add_argument('--kl_scale', default=1.0, type=float)
parser.add_argument('--a',default=2, type=int, help='auc')
parser.add_argument('--dkl_scale', default=1.0, type=float)
parser.add_argument('--model_lr', default=1e-3, type=float)  # Adam learning rate (with --adam)
parser.add_argument('--set_lr', default=1e-3, type=float)
# Used as the first epoch index of the training loop.
parser.add_argument('--shot', default=1, type=int)
parser.add_argument('--rho', default=5, type=float)
# <0: CIFAR10; >=0: DATASET(idx=data_id), logged as "mnist" for 0, else "svhn".
parser.add_argument('--data_id',default=-1, type=int, help='data')
parser.add_argument('--adam',default=False, action='store_true')
# Per-layer channel counts of the pruned network (see the header example).
parser.add_argument('--p', nargs='*',type=int)

parser.add_argument('--bb_lr', default=0.01, type=float)
parser.add_argument('--dbb_lr', default=0.01, type=float)
# parse_known_args: unrecognised command-line tokens are ignored rather than fatal.
args, _ = parser.parse_known_args()

# TensorBoard writer; each run gets its own sub-directory named after --id.
writer = SummaryWriter(os.path.join('./structure/runs', args.id))

# NOTE(review): l, cl, shuffle and batch_num are not referenced elsewhere in
# this file; kept in case other code relies on these module-level names.
l = list(range(100))
cl = list(range(5000))  # TODO:  5000
shuffle = list(range(args.set_size * 10))
batch_num = args.set_size * 10 // args.batch_size

# Log to stdout and to <save_dir>/<id>.log through the root logger.
logname = args.id
os.makedirs(args.save_dir, exist_ok=True)
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save_dir, logname + ".log"))
fh.setFormatter(logging.Formatter(log_format))
logger = logging.getLogger()
logger.addHandler(fh)
logger.info(str(args) + '\n')

# CUDA_VISIBLE_DEVICES only takes effect if set before CUDA is initialised.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
# Benchmark mode picks the fastest cuDNN kernels; it is not deterministic, so
# runs are not bit-reproducible even with a fixed seed.
torch.backends.cudnn.benchmark = True
if args.seed is not None:
    # Seed every RNG the script may touch (Python, NumPy, torch), not only
    # torch, so that non-torch sampling is reproducible as well.
    random.seed(args.seed)
    np.random.seed(args.seed % 2**32)  # NumPy only accepts seeds in [0, 2**32)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)


def test(epoch=None):
    """Evaluate ``net`` on ``test_loader`` and log loss, accuracy and time.

    Args:
        epoch: Step used in the log header and for the TensorBoard scalars;
            passed through unchanged (may be None).

    Reads the module-level globals net, accm, test_loader, criterion,
    logger, writer and TIME.
    """
    net.eval()
    accm.reset()
    total_ms = 0.0  # accumulated GPU time of the forward passes, in ms
    # Evaluation needs no gradients: disabling autograd saves activation
    # memory and keeps graph construction out of the timed forward pass.
    with torch.no_grad():
        for x, y in test_loader:
            x = x.cuda()
            y = y.cuda()

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            # Drain pending work so only the forward pass falls between the events.
            torch.cuda.synchronize()
            start.record()
            outs = net(x)
            end.record()
            # Wait for the recorded events to complete before reading the timing.
            torch.cuda.synchronize()
            total_ms += start.elapsed_time(end)

            loss = criterion(outs, y)
            accm.update([loss.item(), accuracy(outs, y)])
    logger.info(accm.info(header='test', epoch=epoch))

    logger.info('infer {:.4f}'.format(total_ms / 1000))  # seconds

    writer.add_scalar("validation/loss", accm.get('cent'), epoch)
    writer.add_scalar("validation/accuracy", accm.get('acc'), epoch)
    # Test accuracy plotted against cumulative training time (TIME sums seconds).
    writer.add_scalar("validation/time", accm.get('acc'), TIME.get_sum('time'))



# Loss and metric accumulators shared by train() and test().
criterion = nn.CrossEntropyLoss().cuda()
accm = Accumulator('cent', 'acc')
TIME = Accumulator('time')  # cumulative training time, updated per batch by train()
TIME.reset()

if args.data_id >= 0:
    # Non-CIFAR data via the project's DATASET sampler. CIFAR10 is not built
    # in this branch, so selecting another dataset no longer triggers a
    # CIFAR10 download/read that would be thrown away.
    D = DATASET(idx=args.data_id, minibatch_size=args.batch_size,
                setsize=args.set_size, cls_idx=list(range(10)))
    S = D.sample(train=False)  # result unused here; call kept for any side effects
    _, train_loader, test_loader = D.get_loader()
    logger.info("mnist" if args.data_id == 0 else "svhn")
else:
    # CIFAR10 (the comment previously said CIFAR100): random crop + horizontal
    # flip for training, per-channel normalisation for both splits.
    train_dataset = datasets.CIFAR10(root='../dataset/cifar10', train=True,
                                     transform=transforms.Compose([
                                         transforms.RandomCrop(32, padding=4),
                                         transforms.RandomHorizontalFlip(),
                                         transforms.ToTensor(),
                                         transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                     ]), download=True)
    # No download flag: the train split above downloads into the same root.
    test_dataset = datasets.CIFAR10(root='../dataset/cifar10', train=False,
                                    transform=transforms.Compose([
                                        transforms.ToTensor(),
                                        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
                                    ]))
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=args.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=args.batch_size, shuffle=False)
"""
#PRUNED=adapt_channel([] )#[15. 28. 19. 19.  9. 15. 15.  2. 15.  8. 11. 26.  4. 30.  0. 20.] Top-1 err = 85.5        / lr 0.01 -> acc : 24 at training
PRUNED=[27, 36, 30, 34, 64, 67, 75, 63, 131, 121, 131, 126, 247, 114, 255, 210]
PRUNED=[64,64,64,64,128,128,128,128,256,256,256,256,512,512,512,512]
P=[64,64,64,64,128,128,128,128,256,256,256,256,512,512,512,512]
PRUNED=[31, 41, 36, 31, 69, 62, 66, 77, 123, 89, 100, 120, 100, 112, 106, 156]#OURS x4.25
PRUNED=[]
while(True):
    PRUNED=[]
    for c in P:
        p=random.random()
        PRUNED.append(int(c*p))
    r=real_res(PRUNED,logger)

    if 132<r and 150>r:
        break
"""
PRUNED =args.p
logger.info(PRUNED)
real_res(PRUNED,logger)

net=RESNET_PRUNED(num_classes=10,P=PRUNED,logger=logger).cuda()


#net=net_.cuda()

#optimizer = torch.optim.Adam(net.parameters(), lr=1e-3, weight_decay=1e-4)
optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9, weight_decay=1e-4,nesterov=True)
if args.adam:
    optimizer = torch.optim.Adam(net.parameters(), lr=args.model_lr, weight_decay=1e-4)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
        milestones=[int(r*args.num_epochs) for r in [.4,.6]],
        gamma=0.1)
criterion = nn.CrossEntropyLoss().cuda()



def train():
    """Train ``net``, evaluate periodically, then save the weights.

    Runs epochs ``args.shot`` .. ``args.num_epochs - 1``. Each epoch trains on
    ``train_loader``, logs loss/accuracy and the GPU time of
    forward+backward+step, steps ``lr_scheduler``, and calls ``test()`` every
    ``args.eval_freq`` epochs. Finally saves ``{'state_dict': ...}`` to
    ``<save_dir>/<id>model.tar`` and closes the TensorBoard writer.
    """
    for epoch in range(args.shot, args.num_epochs):
        accm.reset()
        lrs = ''.join(' {:.3e}'.format(pg['lr']) for pg in optimizer.param_groups)
        logger.info('\n epoch {} starts with lr'.format(epoch) + lrs)
        net.train()
        epoch_ms = 0.0  # accumulated GPU time of this epoch's steps, in ms

        for x, y in train_loader:
            x = x.cuda()
            y = y.cuda()
            optimizer.zero_grad()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            start.record()
            outs = net(x)
            loss = criterion(outs, y)
            loss.backward()
            optimizer.step()
            end.record()

            # Wait for the recorded events to complete before reading the timing.
            torch.cuda.synchronize()
            step_ms = start.elapsed_time(end)

            epoch_ms += step_ms
            TIME.update([step_ms / 1000])  # TIME accumulates seconds
            # Compute the batch accuracy once and reuse it for both consumers.
            batch_acc = accuracy(outs, y)
            # NOTE: despite the tag name, this plots batch accuracy against
            # cumulative training time.
            writer.add_scalar("training/iter_time", batch_acc, TIME.get_sum('time'))

            accm.update([loss.item(), batch_acc])
        logger.info(accm.info(header='train', epoch=epoch))
        logger.info(epoch_ms / 1000)  # epoch training time in seconds
        writer.add_scalar("training/loss", accm.get('cent'), epoch)
        writer.add_scalar("training/accuracy", accm.get('acc'), epoch)
        logger.info("Time spent: {}".format(TIME.get_sum('time')))
        # Epoch accuracy plotted against cumulative training time.
        writer.add_scalar("training/epoch_time", accm.get('acc'), TIME.get_sum('time'))

        lr_scheduler.step()

        if epoch % args.eval_freq == 0:
            test(epoch=epoch)

    torch.save({'state_dict': net.state_dict()}, os.path.join(args.save_dir, logname + "model.tar"))
    writer.close()



if __name__ == "__main__":
    # Script entry point: train, evaluate periodically, and save the model.
    train()
