# meta-train vgg

import torch
import torch.nn as nn

import argparse
import logging
import os
import sys
import pdb
import random
import pickle
from tensorboardX import SummaryWriter

sys.path.append('../')
from util.accumulator import *
from module.bbdrop import BBDropout
from set.dbbdrop import DBBDropout
from vgg import VGG_CIFAR
from set.set import *
from full.vgg import VGG_CIFAR_gated as VGG

# Command-line options for the meta-training script.
# NOTE: --config, --test, --finetune_epochs and --load are parsed but are not
# read anywhere in the visible part of this file.
parser = argparse.ArgumentParser()
parser.add_argument('--config', default='meta train vgg', type=str)
parser.add_argument('--gpu', default='0', type=str,
                    help='value exported as CUDA_VISIBLE_DEVICES')
parser.add_argument('--seed', default=None, type=int)
parser.add_argument('--test', action='store_true')
parser.add_argument('--num_epochs', default=2000, type=int,
                    help='number of outer (meta) epochs')
parser.add_argument('--inner_epochs', default=5, type=int,
                    help='instance re-samplings (inner passes) per task')
parser.add_argument('--finetune_epochs', default=5, type=int)

parser.add_argument('--save_freq', default=20, type=int,
                    help='save a checkpoint every this many epochs')
parser.add_argument('--eval_freq', default=1, type=int,
                    help='run in_test every this many epochs')
parser.add_argument('--batch_size', type=int, default=32, help='batch size')
parser.add_argument('--gamma', default=1./60000, type=float,
                    help='weight of the regularisation term in the loss')
parser.add_argument('--kl_scale', default=1.0, type=float)

parser.add_argument('--dkl_scale', default=1.0, type=float)
parser.add_argument('--model_lr', default=1e-3, type=float)
parser.add_argument('--set_lr', default=1e-3, type=float)

parser.add_argument('--bb_lr', default=0.01, type=float)
parser.add_argument('--dbb_lr', default=0.01, type=float)
parser.add_argument('--pretrain_path', default='../full/bb/vgg/_model.tar', type=str)
parser.add_argument('--save_dir', type=str, default='./metatrain', help='savedir name')
# The run id names the log file, the checkpoints and the tensorboard run.
parser.add_argument('--id', type=str, default='_',
                    help='run id (log file, checkpoint and tensorboard run name)')
parser.add_argument('--load', action='store_true')


parser.add_argument('--t',default=10, type=int,
                    help='number of tasks sampled per epoch')
parser.add_argument('--set_size',default=64, type=int,
                    help='number of instances sampled per class')
parser.add_argument('--a',default=20, type=int,
                    help='initial a_uc value of the BBDropout gates (a_uc_init)')
# parse_known_args: unrecognised command-line arguments are ignored.
args, _ = parser.parse_known_args()
# first tune
# TensorBoard summary writer; its log directory is ./runs/metatrain/<args.id>.
writer = SummaryWriter(os.path.join('./runs/metatrain',args.id))

def sample2(setsize):
    """Draw ``setsize`` random items from every class of ``dd`` and batch them.

    Reads the module-level names ``dd``, ``shuffle``, ``batch_num`` and
    ``args.batch_size``. ``shuffle`` is permuted in place and then used as
    the order in which the drawn items are laid out before batching. Item
    indices are drawn from ``range(5000)`` for each class, and the label of
    an item is the position of its class inside ``dd``.

    NOTE: a second ``sample2`` with a different signature is defined further
    down in this file and rebinds the name once the module has fully run;
    this definition is the one bound while the module-level setup executes.

    Returns:
        (x, y): two lists with ``batch_num`` entries each; ``x[k]`` is the
        ``torch.stack`` of one batch of items and ``y[k]`` the matching
        ``torch.tensor`` of labels.
    """
    random.shuffle(shuffle)

    per_class_ids = list(range(5000))
    drawn = []
    drawn_labels = []
    for label, class_items in enumerate(dd):
        picks = random.sample(per_class_ids, setsize)
        drawn.extend(class_items[j] for j in picks)
        drawn_labels.extend(label for _ in picks)

    # Lay the drawn items out in the order given by the permuted index list.
    ordered = [drawn[k] for k in shuffle]
    ordered_labels = [drawn_labels[k] for k in shuffle]

    x = []
    y = []
    for b in range(batch_num):
        start = b * args.batch_size
        stop = start + args.batch_size
        x.append(torch.stack(ordered[start:stop]))
        y.append(torch.tensor(ordered_labels[start:stop]))
    return x, y

# The run id doubles as the base name of the log file and the checkpoints.
logname=args.id
if not os.path.isdir(args.save_dir):
    os.makedirs(args.save_dir)
# Log to stdout and, through an extra handler on the root logger, to
# <save_dir>/<id>.log.
log_format = '%(asctime)s %(message)s'
logging.basicConfig(stream=sys.stdout, level=logging.INFO,
                    format=log_format, datefmt='%m/%d %I:%M:%S %p')
fh = logging.FileHandler(os.path.join(args.save_dir, logname+".log"))
fh.setFormatter(logging.Formatter(log_format))
logger = logging.getLogger()
logger.addHandler(fh)
logger.info(str(args) + '\n')


os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
torch.backends.cudnn.benchmark = True
if args.seed is not None:
    # Task and instance sampling in this file goes through Python's `random`
    # (random.sample / random.shuffle), so it has to be seeded as well;
    # seeding torch alone leaves the sampled tasks non-reproducible.
    random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)


# Index pools used by the sampling helpers.
l=list(range(100))                       # ids that train() draws 10-class tasks from
cl=list(range(500))                      # per-class instance indices used by sample()
shuffle=list(range(args.set_size*10))    # permuted in place before every sampled set
shuffle2=list(range(5000))               # index list of the two-argument sample2()


# Number of batches one sampled set (set_size * 10 items) is cut into.
batch_num = args.set_size * 10 // args.batch_size




class MetaLearner(nn.Module):
    """Meta-level network together with its optimizer and LR schedule.

    ``self.net`` is a ``VGG_CIFAR`` with BBDropout gates whose non-classifier,
    non-gate weights are initialised from the checkpoint at
    ``args.pretrain_path``; dependent gates (``DBBDropout``) are added after
    the weights are loaded. Parameters are split into four Adam groups
    (dgate / gate / set_func / everything else), each with its own learning
    rate taken from ``args``.
    """

    def __init__(self):
        # Name the class explicitly: super(self.__class__, self) recurses
        # forever as soon as this class is subclassed.
        super(MetaLearner, self).__init__()

        self.net=VGG_CIFAR(logger=logger)
        self.net.build_gate(BBDropout, {'a_uc_init': args.a, 'kl_scale': args.kl_scale})

        # Take everything except classifier and gate entries from the
        # pretrained checkpoint, dropping any 'module.' fragment in the keys.
        pretrained_state = torch.load(args.pretrain_path)['state_dict']
        kept = {}
        for key, value in pretrained_state.items():
            if 'classifier' not in key and 'gate' not in key:
                kept[key.replace('module.', '')] = value

        # Overlay the kept entries on the freshly built network's own state
        # so that keys missing from the checkpoint keep their initial values.
        merged_state = self.net.state_dict()
        merged_state.update(kept)
        self.net.load_state_dict(merged_state)

        print(self.net.set_apply())
        self.net.build_gate_dep(DBBDropout,{'kl_scale':args.dkl_scale})

        # NOTE: an earlier, disabled variant copied Conv2d/BatchNorm2d weights
        # and BBDropout a_uc/b_uc module by module from a pretrained
        # VGG_CIFAR_gated instance; the state_dict merge above replaces it.

        # Route every parameter to an optimizer group by its name.
        base_params = []
        d_params = []
        bb_params = []
        set_params = []
        for name, param in self.net.named_parameters():
            if 'dgate' in name:
                d_params.append(param)
            elif 'gate' in name:
                bb_params.append(param)
            elif 'set_func' in name:
                set_params.append(param)
            else:
                base_params.append(param)

        self.optimizer = torch.optim.Adam([
            {'params': d_params, 'lr': args.dbb_lr},
            {'params': bb_params, 'lr': args.bb_lr},
            {'params': set_params, 'lr': args.set_lr},
            {'params': base_params, 'lr': args.model_lr}])
        # Learning rates drop by 10x at 50% and at 80% of the epoch budget.
        self.scheduler = torch.optim.lr_scheduler.MultiStepLR(self.optimizer,
                milestones=[int(r*args.num_epochs) for r in [.5, .8]],
                gamma=0.1)

        self.net.cuda()
        self.cent_fn=nn.CrossEntropyLoss().cuda()
        self.accm = Accumulator('cent', 'acc')

    def write_grads( self, sum_grads_pi, dummy_S,dummy_X,dummy_Y):
        """Apply externally computed gradients to ``self.net``.

        A hook is registered on every parameter that replaces whatever
        gradient arrives with ``sum_grads_pi[i]`` (``i`` being the position
        of the parameter in ``self.net.parameters()``). One forward/backward
        pass on the first batch of ``dummy_X``/``dummy_Y`` triggers the
        hooks, the optimizer takes a step, and the hooks are removed again.

        Args:
            sum_grads_pi: indexable collection of gradients, one entry per
                parameter of ``self.net`` in ``parameters()`` order.
            dummy_S: second argument passed to ``self.net`` in the forward pass.
            dummy_X: iterable of input batches; only the first one is used.
            dummy_Y: iterable of label batches; only the first one is used.
        """
        def make_hook(index):
            # A factory call pins `index` for each hook (a bare lambda in the
            # loop would see only the final loop value).
            return lambda grad: sum_grads_pi[index]

        hooks = []
        self.net.train()
        for i, v in enumerate(self.net.parameters()):
            hooks.append(v.register_hook(make_hook(i)))

        x,y=next(iter(zip(dummy_X,dummy_Y)))

        x = x.cuda()
        y = y.long().cuda()
        outs = self.net(x,dummy_S)
        self.optimizer.zero_grad()
        dummy_loss = self.cent_fn(outs, y)
        reg = self.net.get_reg_dep().cuda()+self.net.get_reg().cuda()
        dummy_loss = dummy_loss + args.gamma * reg
        # The backward pass exists only to fire the hooks above.
        dummy_loss.backward()
        self.optimizer.step()

        for h in hooks:
            h.remove()


# Task-level network: a 10-class VGG_CIFAR with BBDropout gates and, after
# set_apply() has been called, dependent DBBDropout gates.
net_pi = VGG_CIFAR(num_classes=10)
net_pi.build_gate(BBDropout, {'a_uc_init': args.a, 'kl_scale': args.kl_scale})
print("net_pi defined: ",net_pi.set_apply())
net_pi.build_gate_dep(DBBDropout,{'kl_scale':args.dkl_scale})

# Route every net_pi parameter to an optimizer group by its name; the order of
# the checks matters because 'dgate' names also contain 'gate'.
base_params = []
d_params = []
bb_params = []
set_params=[]
for name, param in net_pi.named_parameters():
    if 'dgate' in name:
        d_params.append(param)
    elif 'gate' in name:
        bb_params.append(param)
    elif 'set_func' in name:
        set_params.append(param)
    else:
        base_params.append(param)
# One Adam group per parameter family, each with its own learning rate.
optimizer_pi = torch.optim.Adam([
    {'params': d_params, 'lr': args.dbb_lr},
    {'params': bb_params, 'lr': args.bb_lr},
    {'params': set_params, 'lr': args.set_lr},
    {'params': base_params, 'lr': args.model_lr}])
net_pi.cuda()



# Loss and metric accumulator shared by train() and in_test().
cent_fn = nn.CrossEntropyLoss().cuda()
accm = Accumulator('cent', 'acc')

# NOTE(review): pickle.load can execute arbitrary code; only point these paths
# at trusted dataset dumps.
# `d` is indexed as d[class_id][instance] by sample(); `t` is not read
# anywhere in the visible part of this file.
name1 = "../dataset/classtrainimg"
name2 = "../dataset/classtestimg"
with open(name1, "rb") as fb:
    d = pickle.load(fb)
with open(name2, "rb") as fb:
    t = pickle.load(fb)
# `dd` is iterated class by class by the one-argument sample2(); `tt` is not
# read anywhere in the visible part of this file.
name1 = "../dataset/cifar10_cl_train"
name2 = "../dataset/cifar10_cl_test"
with open(name1, "rb") as fb:
    dd = pickle.load(fb)
with open(name2, "rb") as fb:
    tt = pickle.load(fb)
# Fixed set of batches drawn once from `dd`; in_test() reuses it on every call.
# This call resolves to the one-argument sample2 defined above, because the
# two-argument redefinition has not been executed yet at this point.
XX,YY=sample2(args.set_size)
# All batches of XX stacked and viewed as a single set of shape
# (1, set_size*10, 3072); this is the set input given to net_pi in in_test().
SS=torch.stack(XX).cuda()
SS=SS.view(1,args.set_size*10,3072)
# Building the meta learner loads the checkpoint at args.pretrain_path.
M=MetaLearner()

def train():
    """Run the outer meta-training loop for ``args.num_epochs`` epochs.

    Per epoch, ``args.t`` tasks of 10 ids each are drawn from ``l``. For every
    task ``net_pi`` is reset from ``M.net`` via ``update_pi`` and trained for
    ``args.inner_epochs`` passes, each on a freshly sampled set. The gradient
    of the last inner loss with respect to ``net_pi``'s parameters is summed
    over the tasks and handed to ``M.write_grads`` together with the last
    sampled set. Evaluation and checkpointing follow ``args.eval_freq`` and
    ``args.save_freq``; a final checkpoint and evaluation run after the loop.
    """
    for epoch in range(1, args.num_epochs+1):
        accm.reset()
        line = 'epoch {} starts with lr'.format(epoch)
        for pg in optimizer_pi.param_groups:
            line += ' {:.3e}'.format(pg['lr'])
        logger.info(line)

        grads_pi = None
        net_pi.train()
        X=None
        Y=None
        S=None

        for task_idx in range(args.t):
            N = random.sample(l, 10)
            # Every task starts from the current meta parameters.
            update_pi(M.net)
            line ="epoch "+str(epoch)+" sampled " +str(task_idx) + "th task :  " + str(N)
            logger.info(line + '\n')

            for steps in range(args.inner_epochs): #inner epoch=instance sampling num
                X,Y=sample(args.set_size,N)
                S=torch.stack(X).cuda()
                S=S.view(1,args.set_size*len(N),3072)
                accm.reset()
                for x, y in zip(X,Y):
                    x = x.cuda()
                    y = y.long().cuda()
                    optimizer_pi.zero_grad()
                    outs = net_pi(x,S)
                    cent = cent_fn(outs, y)
                    loss=cent
                    reg = net_pi.get_reg_dep().cuda()+net_pi.get_reg().cuda()
                    loss = loss + args.gamma*reg
                    if steps==args.inner_epochs-1:
                        # Keep the graph alive: the last loss of the final
                        # pass is differentiated again after this loop.
                        loss.backward(retain_graph=True)

                    else:
                        loss.backward()

                    optimizer_pi.step()
                    accm.update([cent.item(), accuracy(outs, y)])

                line = "inner pi update "+accm.info(header='train', epoch=steps)
                logger.info(line )
                logger.info("task reg: {}".format(net_pi.get_reg().item()))
                logger.info("task reg(dep): {}".format(net_pi.get_reg_dep().item()))

            # NOTE(review): optimizer_pi.step() has already updated the
            # parameters in place when the retained graph is walked again
            # here; recent PyTorch versions may reject that as an in-place
            # modification of saved tensors - confirm on the target version.
            task_grads = torch.autograd.grad(loss, net_pi.parameters())
            if grads_pi is None:
                grads_pi = task_grads
            else:
                grads_pi = [torch.add(i, j) for i, j in zip(grads_pi, task_grads)]

        M.write_grads(grads_pi,S,X,Y)
        M.scheduler.step()
        if epoch % args.eval_freq == 0:
            in_test(args.set_size,N,epoch)
        if epoch % args.save_freq == 0:
            torch.save({'state_dict':M.net.state_dict()}, os.path.join(args.save_dir,logname+str(epoch)+'.tar'))
    torch.save({'state_dict':M.net.state_dict()}, os.path.join(args.save_dir, logname))

    # Equals the last loop value of `epoch`, but stays defined when the loop
    # body never ran (num_epochs < 1).
    in_test(5, l, args.num_epochs)





def in_test(set_size,N,epoch):
    """Fine-tune ``net_pi`` on the fixed set ``XX``/``YY`` and report metrics.

    ``net_pi`` is reset from ``M.net``, trained for three passes over the
    fixed batches with ``SS`` as set input, then switched to eval mode and
    scored on the same batches. Loss, accuracy, regulariser values and the
    pruning statistics exposed by ``net_pi`` are logged and written to the
    TensorBoard ``writer`` under ``validation/*``.

    Args:
        set_size: not used by this function; kept for the existing callers.
        N: not used by this function; kept for the existing callers.
        epoch: value attached to the log lines and TensorBoard scalars.
    """
    update_pi(M.net)
    net_pi.train()
    # NOTE(review): the pass count is fixed at 3 although a --finetune_epochs
    # option exists - confirm which one is intended.
    for i in range(3):
        accm.reset()
        for x, y in zip(XX,YY): #fixed reduced situation see trend
            x = x.cuda()
            y = y.long().cuda()
            optimizer_pi.zero_grad()
            outs = net_pi(x,SS)
            cent = cent_fn(outs, y)
            loss=cent
            reg = net_pi.get_reg_dep().cuda()+net_pi.get_reg().cuda()
            loss = loss + args.gamma*reg
            loss.backward()
            optimizer_pi.step()
            accm.update([cent.item(), accuracy(outs, y)])
        logger.info(accm.info(header='in_test', epoch=epoch))

    net_pi.eval()# mask to 0,1  good performance if converged
    # Start from a clean accumulator so the reported 'test' numbers cannot
    # include statistics of the last fine-tuning pass.
    accm.reset()
    # No backward pass follows, so skip building autograd graphs.
    with torch.no_grad():
        for x, y in zip(XX,YY): #fixed reduced situation see trend
            x = x.cuda()
            y = y.long().cuda()
            outs = net_pi(x,SS)
            cent = cent_fn(outs, y)
            accm.update([cent.item(), accuracy(outs, y)])
    PRUNED=net_pi.get_pruned_size_dep()
    logger.info(accm.info(header='test', epoch=epoch))
    logger.info('reg(dep) {:.4f}'.format(net_pi.get_reg_dep().item()))
    logger.info('reg {:.4f}'.format(net_pi.get_reg().item()))
    # 20086692: presumably the weight count of the unpruned network - TODO confirm.
    logger.info('nonactive weight {} ({:.4f}%)'.format(str(net_pi.get_pruned_weight_sum()),net_pi.get_pruned_weight_sum()*100/20086692))
    logger.info('pruned size {}'.format(str(net_pi.get_pruned_size())))
    logger.info('pruned size (dep) {}'.format(str(PRUNED)))
    logger.info('speedup in flops {:.4f}'.format(net_pi.get_speedup_dep()))
    logger.info('memory saving {:.4f}\n'.format(net_pi.get_memory_saving_dep()))

    writer.add_scalar("validation/loss",accm.get('cent') , epoch)
    writer.add_scalar("validation/accuracy", accm.get('acc'),epoch)
    writer.add_scalar("validation/nonactive_weights", net_pi.get_pruned_weight_sum()*100/20086692,epoch)
    writer.add_scalar("validation/flop_speedup", net_pi.get_speedup_dep(),epoch)
    writer.add_scalar("validation/memory_saved", net_pi.get_memory_saving_dep(),epoch)



def sample(setsize,N):
    """Draw ``setsize`` random instances for every class id in ``N`` and batch them.

    Reads the module-level names ``d``, ``cl``, ``shuffle``, ``batch_num`` and
    ``args.batch_size``. ``shuffle`` is permuted in place and then used as the
    order in which the drawn instances are laid out before batching. Labels
    are positions inside ``N``, not the class ids themselves.

    Args:
        setsize: number of instance indices drawn from ``cl`` per class.
        N: sequence of class ids used to index ``d``.

    Returns:
        (x, y): two lists with ``batch_num`` entries each; ``x[k]`` is the
        ``torch.stack`` of one batch of instances and ``y[k]`` the matching
        ``torch.tensor`` of labels.
    """
    random.shuffle(shuffle)

    drawn = []
    drawn_labels = []
    for label, class_id in enumerate(N):
        picks = random.sample(cl, setsize)
        drawn.extend(d[class_id][j] for j in picks)
        drawn_labels.extend(label for _ in picks)

    # Lay the drawn instances out in the order given by the permuted index list.
    ordered = [drawn[k] for k in shuffle]
    ordered_labels = [drawn_labels[k] for k in shuffle]

    x = []
    y = []
    for b in range(batch_num):
        start = b * args.batch_size
        stop = start + args.batch_size
        x.append(torch.stack(ordered[start:stop]))
        y.append(torch.tensor(ordered_labels[start:stop]))
    return x, y


def sample2(setsize,N):
    """Batch every instance listed in ``cl`` for each class id in ``N``.

    Unlike ``sample`` nothing is subsampled: for each class id all indices in
    ``cl`` are taken from ``d``. The collected instances are then reordered by
    the module-level index list ``shuffle2`` and cut into batches of
    ``args.batch_size``.

    The original code permuted ``shuffle`` here although the order actually
    applied is ``shuffle2``, which therefore always stayed in class order;
    ``shuffle2`` is now the list that gets permuted.

    NOTE: this definition rebinds the name of the one-argument ``sample2``
    defined earlier in the file once the module has fully executed.

    Args:
        setsize: not used by this function; kept for signature compatibility.
        N: sequence of class ids used to index ``d``. ``shuffle2`` indexes
            into ``len(N) * len(cl)`` collected items, so N is presumably
            expected to hold 10 ids - TODO confirm.

    Returns:
        (x, y): two lists with ``len(shuffle2) // args.batch_size`` entries
        each; ``x[k]`` is the ``torch.stack`` of one batch of instances and
        ``y[k]`` the matching ``torch.tensor`` of labels (positions in ``N``).
    """
    # Permute the index list that is actually applied below.
    random.shuffle(shuffle2)

    collected = []
    collected_labels = []
    for label, class_id in enumerate(N):
        for j in cl:
            collected.append(d[class_id][j])
            collected_labels.append(label)

    ordered = [collected[k] for k in shuffle2]
    ordered_labels = [collected_labels[k] for k in shuffle2]

    x = []
    y = []
    # Trailing items that do not fill a whole batch are dropped.
    for b in range(len(shuffle2) // args.batch_size):
        start = b * args.batch_size
        stop = start + args.batch_size
        x.append(torch.stack(ordered[start:stop]))
        y.append(torch.tensor(ordered_labels[start:stop]))
    return x, y



def update_pi(net):
    """Overwrite tensors of ``net_pi`` with clones taken from ``net``.

    The modules of both networks are walked in lockstep via ``modules()``;
    which attributes get copied is decided by the type of the ``net_pi``
    module. Each type check is independent of the others.

    Args:
        net: source network whose module sequence lines up with ``net_pi``'s.
    """
    def clone_attrs(src, dst, names):
        # Rebind dst.<name>.data to a fresh copy of src.<name>.data.
        for attr in names:
            getattr(dst, attr).data = getattr(src, attr).data.clone()

    weighted_layers = (nn.Conv2d, nn.BatchNorm2d, nn.BatchNorm1d)
    for source, target in zip(net.modules(), net_pi.modules()):
        if isinstance(target, weighted_layers):
            clone_attrs(source, target, ('weight',))
            if target.bias is not None:
                clone_attrs(source, target, ('bias',))
        if isinstance(target, DBBDropout):
            clone_attrs(source, target, ('sigma_uc', 'r', 'b'))
        if isinstance(target, BBDropout):
            clone_attrs(source, target, ('a_uc', 'b_uc'))
        # @for set Transformer
        if isinstance(target, ISAB):
            clone_attrs(source, target, ('I',))
        if isinstance(target, PMA):
            clone_attrs(source, target, ('S',))

if __name__ == '__main__':
    # Script entry point: run meta-training, then release the TensorBoard writer.
    try:
        train()
    finally:
        # Close the writer even when training is interrupted or raises, so
        # it is not left open.
        writer.close()
