# Fine-tune a meta-trained ResNet-18 on CIFAR-10 / SVHN.
# For SVHN, pass --data_id 1.

import argparse
import torch
import torch.nn as nn
import logging
import os
import sys
import math
import random
sys.path.append('../')
sys.path.append('../../')
from util.accumulator import *

from pruned import  *
from model import Net_set
from module.bbdrop import BBDropout

from set.dbbdrop import DBBDropout
from util.accumulator import *

from torchvision import datasets
import torchvision.transforms as transforms
from torch.utils.data import Dataset
import numpy as np
from tensorboardX import SummaryWriter
from torchsummary import summary
import pickle
from data_sampler import *
from util.real_flops import real

# Command-line options. parse_known_args() is used, so unrecognised flags are
# silently ignored instead of aborting the run.
parser = argparse.ArgumentParser()
parser.add_argument('--config', default='finetune metatrained resnet18 on cifar10 /svhn', type=str)
parser.add_argument('--gpu', default='0', type=str, help='value exported as CUDA_VISIBLE_DEVICES')
parser.add_argument('--seed', default=42, type=int)
parser.add_argument('--num_epochs', default=200, type=int)
parser.add_argument('--eval_freq', default=1, type=int)
parser.add_argument('--batch_size', type=int, default=128, help='batch size')
parser.add_argument('--set_size', type=int, default=64,
                    help='number of images sampled per class to build the set input')

parser.add_argument('--pretrain_path', type=str, default='./metatrain/_3000.tar',
                    help='path of the meta-trained checkpoint to load')
parser.add_argument('--save_dir', type=str, default='./finetune', help='savedir name')
parser.add_argument('--id', default="_", type=str)
parser.add_argument('--lr', default=1e-1, type=float)
parser.add_argument('--gamma', default=1./60000, type=float)
parser.add_argument('--kl_scale', default=1.0, type=float)
parser.add_argument('--a', default=2, type=int, help='a_uc_init value handed to the BBDropout gates')
parser.add_argument('--dkl_scale', default=1.0, type=float)
parser.add_argument('--model_lr', default=1e-3, type=float)
parser.add_argument('--set_lr', default=1e-3, type=float)
parser.add_argument('--shot', default=1, type=int, help='number of epochs of the gate-training phase')
parser.add_argument('--rho', default=5, type=float)
parser.add_argument('--data_id', default=-1, type=int,
                    help='dataset selector: -1 loads CIFAR-10, any other value is passed to DATASET')
parser.add_argument('--adam', default=False, action='store_true')
parser.add_argument('--glance', default=50000, type=int,
                    help='number of CIFAR-10 training images used during the gate-training phase')

parser.add_argument('--bb_lr', default=0.5, type=float)
parser.add_argument('--dbb_lr', default=0.5, type=float)
args, _ = parser.parse_known_args()

# TensorBoard events of this run are written below ./finetune/runs/<id>.
writer = SummaryWriter(os.path.join('./finetune/runs', args.id))

# Index helpers consumed by sample(): class ids, per-class image ids and the
# permutation that is shuffled in place on every call.
l = list(range(10))
cl = list(range(5000))  # TODO:  5000
shuffle = list(range(args.set_size * 10))
# Number of whole mini-batches that fit into the sampled set.
batch_num = (args.set_size * 10) // args.batch_size

# Log to stdout and, in parallel, to <save_dir>/<id>.log.
logname = args.id
os.makedirs(args.save_dir, exist_ok=True)

log_format = '%(asctime)s %(message)s'
logging.basicConfig(
    stream=sys.stdout,
    level=logging.INFO,
    format=log_format,
    datefmt='%m/%d %I:%M:%S %p',
)

fh = logging.FileHandler(os.path.join(args.save_dir, logname + ".log"))
fh.setFormatter(logging.Formatter(log_format))

# The file handler is attached to the root logger, which is what the rest of
# the script logs through.
logger = logging.getLogger()
logger.addHandler(fh)
logger.info("{}\n".format(str(args)))


# Restrict the visible GPUs before any CUDA call is made.
os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
torch.backends.cudnn.benchmark = True
if args.seed is not None:
    # sample() draws with the stdlib `random` module, so seeding torch alone
    # does not make a run repeatable; seed every generator the script imports.
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)



def test1(epoch=None):
    """Evaluate the gated network on ``train_loader0`` and log pruning statistics.

    The forward pass uses the module-level set input ``S``. Loss and accuracy
    are accumulated in ``accm`` and written to the log and to TensorBoard.
    Note that the log header reads 'test' although the data comes from
    ``train_loader0``.

    Args:
        epoch: index of the gate-training epoch that has just finished.

    Returns:
        ``(pruned, PRUNED)`` when ``epoch == args.shot - 1``, where
        ``pruned[k]`` lists the flat positions at which the dependent mask of
        gated layer ``k`` is non-zero and ``PRUNED`` is the value returned by
        ``net.module.get_pruned_size_dep()``; ``(None, None)`` for every other
        epoch.
    """
    net.eval()
    accm.reset()
    # Pure evaluation: do not record an autograd graph for the forward passes.
    with torch.no_grad():
        for x, y in train_loader0:
            x = x.cuda()
            y = y.cuda()
            outs = net(x, S)
            loss = criterion(outs, y)
            accm.update([loss.item(), accuracy(outs, y)])
    logger.info(accm.info(header='test', epoch=epoch))

    PRUNED = net.module.get_pruned_size_dep()
    nonactive = net.module.get_pruned_weight_sum()
    logger.info('reg(dep) {:.4f}'.format(net.module.get_reg_dep().item()))
    logger.info('reg {:.4f}'.format(net.module.get_reg().item()))
    # 20040052 is presumably the weight count of the unpruned network -- TODO confirm.
    logger.info('nonactive weight {} ({:.4f}%)'.format(str(nonactive), nonactive*100/20040052))
    logger.info('pruned size {}'.format(str(net.module.get_pruned_size())))
    logger.info('pruned size (dep) {}'.format(str(PRUNED)))
    logger.info('speedup in flops {:.4f}'.format(net.module.get_speedup_dep()))
    logger.info('memory saving {:.4f}\n'.format(net.module.get_memory_saving_dep()))
    logger.info('\n')
    writer.add_scalar("validation/loss", accm.get('cent'), epoch)
    writer.add_scalar("validation/accuracy", accm.get('acc'), epoch)

    # Accuracy plotted against accumulated training time instead of epoch.
    writer.add_scalar("validation/time", accm.get('acc'), TIME.get_sum('time'))

    if epoch != args.shot - 1:
        return None, None

    # Last gate-training epoch: collect, per gated layer, the positions whose
    # dependent mask is non-zero. The mask is moved to a Python list once, so
    # the scan does not issue one tensor comparison per element.
    pruned = []
    with torch.no_grad():
        for gate in net.module.gated_layers:
            mask = gate.get_mask_dep().view(-1).tolist()
            pruned.append([idx for idx, el in enumerate(mask) if el != 0])
    return pruned, PRUNED

def test(epoch=None):
    """Evaluate the pruned network on ``test_loader`` and time the forward passes.

    Loss and accuracy are accumulated in ``accm``; the GPU time spent in the
    forward pass plus loss computation is summed over all batches and logged
    in seconds. Results are also written to TensorBoard.

    Args:
        epoch: index of the fine-tuning epoch that has just finished.
    """
    net.eval()
    accm.reset()
    tt = 0
    # Pure evaluation: do not record an autograd graph for the forward passes.
    with torch.no_grad():
        for x, y in test_loader:
            x = x.cuda()
            y = y.cuda()
            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)
            torch.cuda.synchronize()

            start.record()
            outs = net(x)
            loss = criterion(outs, y)
            end.record()

            # Wait for the recorded events before reading the elapsed time.
            torch.cuda.synchronize()
            tt += start.elapsed_time(end)  # milliseconds
            accm.update([loss.item(), accuracy(outs, y)])
    logger.info(accm.info(header='test', epoch=epoch))

    logger.info("infer {}sec".format(tt/1000))

    writer.add_scalar("validation/loss", accm.get('cent'), epoch)
    writer.add_scalar("validation/accuracy", accm.get('acc'), epoch)
    # Accuracy plotted against accumulated training time instead of epoch.
    writer.add_scalar("validation/time", accm.get('acc'), TIME.get_sum('time'))




def sample(setsize,N):
    """Draw a shuffled, batched set of images from the class-wise cache ``d``.

    ``setsize`` distinct images are drawn from each class listed in ``N``.
    Images are labelled by their position in ``N`` (not by the class id),
    shuffled jointly with their labels, and cut into mini-batches of
    ``args.batch_size``; a trailing partial batch is dropped.

    The permutation, the per-class population and the number of batches are
    derived from the actual arguments, so the function also works when
    ``setsize`` differs from ``args.set_size``, when ``N`` does not hold ten
    classes, or when a class holds a number of images other than 5000.

    Args:
        setsize: number of images to draw per class.
        N: iterable of class keys into ``d``.

    Returns:
        ``(x, y)``: a list of stacked image batches and the list of matching
        label tensors.
    """
    classes = list(N)
    total = setsize * len(classes)

    # The permutation is drawn first, then the per-class images.
    order = list(range(total))
    random.shuffle(order)

    images = []
    labels = []
    for label, c in enumerate(classes):
        for j in random.sample(range(len(d[c])), setsize):
            images.append(d[c][j])
            labels.append(label)

    # Apply the same permutation to images and labels so they stay aligned.
    images = [images[k] for k in order]
    labels = [labels[k] for k in order]

    x = []
    y = []
    for b in range(total // args.batch_size):
        lo = b * args.batch_size
        hi = lo + args.batch_size
        x.append(torch.stack(images[lo:hi]))
        y.append(torch.tensor(labels[lo:hi]))
    return x, y

if args.data_id != -1:
    # Any other id: the project's DATASET helper supplies all three loaders.
    D = DATASET(idx=args.data_id, minibatch_size=args.batch_size, cls_idx=list(range(10)))
    train_loader0, train_loader, test_loader = D.get_loader()
    logger.info("mnist" if args.data_id == 0 else "svhn")
else:
    # CIFAR-10: augmented training split, plain test split.
    _cifar10_mean = (0.4914, 0.4822, 0.4465)
    _cifar10_std = (0.2023, 0.1994, 0.2010)
    train_dataset = datasets.CIFAR10(
        root='../dataset/cifar10',
        train=True,
        transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            transforms.Normalize(_cifar10_mean, _cifar10_std),
        ]),
        download=True,
    )
    test_dataset = datasets.CIFAR10(
        root='../dataset/cifar10',
        train=False,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize(_cifar10_mean, _cifar10_std),
        ]),
    )

    # The gate-training phase may only "glance" at a random subset of the
    # training split; the fine-tuning phase always sees all of it.
    glace_train = train_dataset
    if args.glance != 50000:
        glace_train, _ = torch.utils.data.random_split(
            train_dataset, [args.glance, 50000 - args.glance])
    print(len(glace_train))

    train_loader0 = torch.utils.data.DataLoader(
        dataset=glace_train, batch_size=args.batch_size, shuffle=True)
    train_loader = torch.utils.data.DataLoader(
        dataset=train_dataset, batch_size=args.batch_size, shuffle=True)
    test_loader = torch.utils.data.DataLoader(
        dataset=test_dataset, batch_size=args.batch_size, shuffle=False)


# Class-wise CIFAR-10 caches: one list of image tensors per class, pickled for
# the training split (name1) and the test split (name2).
name1 = "../dataset/cifar10_cl_train"
name2 = "../dataset/cifar10_cl_test"

# Both files are loaded below, so both must exist; only missing files are
# (re)built and an existing cache is never overwritten.
_missing_caches = [path for path in (name1, name2) if not os.path.exists(path)]
if _missing_caches:
    if args.data_id != -1:
        # train_dataset / test_dataset are only created on the CIFAR-10 path,
        # so the cache cannot be built here; fail with an actionable message.
        raise RuntimeError(
            "class-wise CIFAR-10 cache not found ({}); run once with --data_id -1 "
            "to create it".format(", ".join(_missing_caches)))

    for _path, _dataset in ((name1, train_dataset), (name2, test_dataset)):
        if _path not in _missing_caches:
            continue
        _per_class = [[] for _ in range(10)]
        # Images are stored exactly as the dataset yields them, i.e. after
        # its transform pipeline has been applied.
        for image, target in _dataset:
            _per_class[target].append(image)
        with open(_path, "wb") as fp:  # Pickling
            pickle.dump(_per_class, fp)
    print("Done")


# NOTE: pickle.load executes arbitrary code from the file; these caches must
# only ever be files written by this script.
with open(name1, "rb") as fb:
    d = pickle.load(fb)
with open(name2, "rb") as fb:
    t = pickle.load(fb)


# Build the gated network, attach both kinds of gates, then restore the
# meta-trained weights. The gates must exist before load_state_dict is called.
net = Net_set(num_classes=10)
net.build_gate(BBDropout, {'a_uc_init': args.a, 'kl_scale': args.kl_scale})
print(net.set_apply())
net.build_gate_dep(DBBDropout, {'kl_scale': args.dkl_scale, 'rho': math.sqrt(args.rho)})
ckpt = torch.load(args.pretrain_path)
net.load_state_dict(ckpt['state_dict'])
net = nn.DataParallel(net)
net.cuda()

# Split the parameters into four groups by name so each gets its own learning
# rate. The first matching keyword wins, so 'dgate' is tested before 'gate'.
base_params = []
d_params = []
bb_params = []
set_params = []
_keyword_groups = (
    ('dgate', d_params),
    ('gate', bb_params),
    ('set_func', set_params),
)
for name, param in net.named_parameters():
    target = next((group for key, group in _keyword_groups if key in name), base_params)
    target.append(param)

optimizer = torch.optim.Adam([
    {'params': d_params, 'lr': args.dbb_lr},
    {'params': bb_params, 'lr': args.bb_lr},
    {'params': set_params, 'lr': args.set_lr},
    {'params': base_params, 'lr': args.model_lr},
])
# A step schedule is only installed for the special case --shot 200.
if args.shot == 200:
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
        optimizer,
        milestones=[int(r * args.num_epochs) for r in [.3, .6, .8]],
        gamma=0.1)

criterion = nn.CrossEntropyLoss().cuda()
accm = Accumulator('cent', 'acc')
TIME = Accumulator('time')
TIME.reset()

net.train()
# Draw the set input once; every sampled image is flattened to 3072 values
# and all of them are placed in a single batch entry.
X, Y = sample(args.set_size, [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
S = torch.stack(X).cuda().view(1, -1, 3072)


# Phase 1: train the gated network for args.shot epochs on train_loader0,
# then read the surviving positions (P) and layer sizes (PRUNED) from test1().
PRUNED = None
P = None
for i in range(args.shot):
    # test1() leaves the network in eval mode at the end of every epoch, so
    # training mode has to be restored before the next epoch starts.
    net.train()
    accm.reset()
    tt = 0
    for x, y in train_loader0:
        x = x.cuda()
        y = y.long().cuda()
        optimizer.zero_grad()
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        outs = net(x, S)
        cent = criterion(outs, y)
        # Cross-entropy plus the gamma-weighted sum of both gate regularisers.
        reg = net.module.get_reg_dep().cuda() + net.module.get_reg().cuda()
        loss = cent + args.gamma * reg
        loss.backward()
        optimizer.step()
        end.record()

        # Wait for the recorded events before reading the elapsed time.
        torch.cuda.synchronize()
        t = start.elapsed_time(end)  # milliseconds
        tt += t

        TIME.update([t / 1000])

        # Compute the batch accuracy once and reuse it for both consumers.
        acc = accuracy(outs, y)
        writer.add_scalar("training/iter_time", acc, TIME.get_sum('time'))
        accm.update([cent.item(), acc])
    line = accm.info(header='train', epoch=i)
    logger.info(line)
    logger.info("{:.4f}".format(tt))
    writer.add_scalar("training/loss", accm.get('cent'), i)
    writer.add_scalar("training/accuracy", accm.get('acc'), i)
    logger.info("Time spent: {}".format(TIME.get_sum('time')))
    writer.add_scalar("training/epoch_time", accm.get('acc'), TIME.get_sum('time'))

    # Only the call for the last epoch returns the masks; earlier calls
    # return (None, None).
    P, PRUNED = test1(epoch=i)
    if args.shot == 200:
        lr_scheduler.step()

if P is None or PRUNED is None:
    # Happens when the loop above never ran (--shot below 1).
    raise RuntimeError("no pruning mask was collected; --shot must be at least 1")

# Make the reported layer sizes agree with the number of surviving positions.
for layer_no, kept in enumerate(P):
    if PRUNED[layer_no] != len(kept):
        print(PRUNED[layer_no], "!=", len(kept))
        PRUNED[layer_no] = len(kept)

logger.info(" pruned ARCHITECTURE: {}".format(str(PRUNED)))

#logger.info("{:.3f} % MEMORY".format(count_memory(PRUNED)*100/count_memory([64, 64, 128, 128, 256, 256, 256, 256, 512, 512, 512, 512, 512, 512, 512, 512])))


# Phase 2 setup: build the compact network described by PRUNED and fill its
# gated convolutions with the surviving filters of the trained gated network.
real_res(PRUNED, logger)

net_ = RESNET_PRUNED(num_classes=10, P=PRUNED, logger=logger)

# conv0 is not gated: copy its parameters over unchanged (the "module." prefix
# added by DataParallel is stripped from the keys).
pruned_state = net_.state_dict()
pruned_state.update({k.replace("module.", ""): v for k, v in net.state_dict().items() if 'conv0' in k})
net_.load_state_dict(pruned_state)

# Walk the convolutions in module order. The first one is conv0 (handled
# above); convolutions whose kernel width is 1 are left untouched; every other
# convolution is matched, in order, with net.module.gated_layers[layer_idx].
layer_idx = -1
for m in net_.modules():
    if not isinstance(m, nn.Conv2d):
        continue
    if layer_idx == -1:
        layer_idx = 0
        continue
    if m.weight.data.size(3) == 1:
        continue

    source = net.module.gated_layers[layer_idx].base
    out_idx = torch.tensor(P[layer_idx], dtype=torch.long, device=source.weight.device)

    # Keep the surviving output filters ...
    kept = source.weight.data.index_select(0, out_idx)
    if layer_idx > 0:
        print(layer_idx)
        # ... and, past the first gated layer, only the input channels that
        # survived in the previous gated layer.
        in_idx = torch.tensor(P[layer_idx - 1], dtype=torch.long, device=source.weight.device)
        kept = kept.index_select(1, in_idx)
    print(kept.size())
    m.weight.data = kept.clone()

    if m.bias is not None:
        m.bias.data = source.bias.data.index_select(0, out_idx).clone()
    layer_idx += 1



# From here on `net` is the pruned network.
net = nn.DataParallel(net_).cuda()
print(net)

# Build exactly one optimizer: Adam when --adam is given, Nesterov SGD otherwise.
if args.adam:
    optimizer = torch.optim.Adam(net.parameters(), lr=args.model_lr, weight_decay=1e-4)
else:
    optimizer = torch.optim.SGD(net.parameters(), lr=args.lr, momentum=0.9,
                                weight_decay=1e-4, nesterov=True)
# Decay the learning rate tenfold at 40% and 60% of --num_epochs.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
    optimizer,
    milestones=[int(r * args.num_epochs) for r in [.4, .6]],
    gamma=0.1)
criterion = nn.CrossEntropyLoss().cuda()



def train():
    """Fine-tune the pruned network and save its weights.

    Epoch numbering continues from ``args.shot`` up to ``args.num_epochs``.
    Each epoch trains on ``train_loader``, logs loss/accuracy and the
    accumulated GPU time, steps ``lr_scheduler`` and, every ``args.eval_freq``
    epochs, evaluates with ``test``. After the last epoch the state dict is
    written to ``<save_dir>/<id>model.tar`` and the TensorBoard writer is
    closed.
    """
    for epoch in range(args.shot, args.num_epochs):
        accm.reset()
        line = '\n epoch {} starts with lr'.format(epoch)
        for pg in optimizer.param_groups:
            line += ' {:.3e}'.format(pg['lr'])
        logger.info(line)
        net.train()

        for x, y in train_loader:
            x = x.cuda()
            y = y.cuda()
            optimizer.zero_grad()

            start = torch.cuda.Event(enable_timing=True)
            end = torch.cuda.Event(enable_timing=True)

            start.record()
            outs = net(x)
            loss = criterion(outs, y)
            loss.backward()
            optimizer.step()
            end.record()

            # Wait for the recorded events before reading the elapsed time.
            torch.cuda.synchronize()
            t = start.elapsed_time(end)  # milliseconds

            TIME.update([t/1000])

            # Compute the batch accuracy once and reuse it for both consumers.
            acc = accuracy(outs, y)
            writer.add_scalar("training/iter_time", acc, TIME.get_sum('time'))
            accm.update([loss.item(), acc])

        line = accm.info(header='train', epoch=epoch)
        logger.info(line)
        writer.add_scalar("training/loss", accm.get('cent'), epoch)
        writer.add_scalar("training/accuracy", accm.get('acc'), epoch)
        logger.info("Time spent: {}".format(TIME.get_sum('time')))
        # Accuracy plotted against accumulated training time instead of epoch.
        writer.add_scalar("training/epoch_time", accm.get('acc'), TIME.get_sum('time'))

        lr_scheduler.step()

        if epoch % args.eval_freq == 0:
            test(epoch=epoch)

    torch.save({'state_dict': net.state_dict()}, os.path.join(args.save_dir, logname+"model.tar"))
    writer.close()



# Only the fine-tuning of the pruned network is guarded here; argument
# parsing, data loading, gate training and the construction of the pruned
# network above all run at module level.
if __name__ == '__main__':
    train()
