import argparse
import os
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import pickle
import glob



def get_model(dir_, weights_dict, bias_dict):
    """Load one trained network and collect its Conv2d/Linear parameters.

    Reads the pickled training arguments from ``<dir_>/args.pkl``, rebuilds the
    architecture named by ``args.arch`` from the local ``resnet`` module,
    restores ``<args.save_dir>/model.th`` when that file exists, and appends a
    clone of every ``nn.Conv2d`` / ``nn.Linear`` weight (and bias, when present)
    to ``weights_dict`` / ``bias_dict``. Both dicts are keyed by the module's
    index in ``model.modules()``. Calling this repeatedly for models of the same
    architecture therefore builds one list per layer.

    Side effects: sets ``args.evaluate`` / ``args.resume`` (and
    ``args.start_epoch`` when a checkpoint is loaded), creates
    ``args.save_dir`` if needed, and rewrites ``<args.save_dir>/args.pkl``.

    Args:
        dir_: directory containing ``args.pkl``.
        weights_dict: module index -> list of weight tensors; updated in place.
        bias_dict: module index -> list of bias tensors; updated in place.

    Returns:
        ``(model, args, weights_dict, bias_dict)``.

    Raises:
        ValueError: if ``args.model_type`` is not ``"normal"``.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only point this at
    # trusted experiment directories.
    with open(os.path.join(dir_, "args.pkl"), "rb") as input_file:
        args = pickle.load(input_file)

    args.evaluate = True

    if args.model_type == "normal":
        import resnet as resnet
    else:
        # The original `raise("...")` raised a str, which is itself a TypeError
        # in Python 3 and hid this message; raise a real exception instead.
        raise ValueError(
            "Something is off: unsupported model_type {!r}".format(args.model_type))

    print(args.save_dir)
    args.resume = os.path.join(args.save_dir, 'model.th')

    # Make sure save_dir exists, then persist the (updated) arguments there.
    os.makedirs(args.save_dir, exist_ok=True)
    with open(os.path.join(args.save_dir, 'args.pkl'), 'wb') as handle:
        pickle.dump(args, handle, protocol=pickle.HIGHEST_PROTOCOL)

    if args.model_type in ("normal", "drop"):
        model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
    else:
        # Currently unreachable: only "normal" passes the check above.
        model = torch.nn.DataParallel(resnet.__dict__[args.arch](m=args.Mense))

    model.cuda()

    # Resume from the checkpoint stored next to args.pkl, if any.
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = 200  # hard-coded; checkpoint['epoch'] is ignored
        model.load_state_dict(checkpoint['state_dict'])
    else:
        # NOTE(review): the freshly initialised weights are still collected
        # below, so a missing checkpoint silently pollutes the fused model.
        print("=> no checkpoint found at '{}'".format(args.resume))

    # Keys are positions in model.modules(); identical architectures yield
    # identical positions, which is what lets fuse_models line layers up.
    for idx, module in enumerate(model.modules()):
        if type(module) in (nn.Conv2d, nn.Linear):
            weights_dict.setdefault(idx, []).append(module.weight.data.clone())
            layer_biases = bias_dict.setdefault(idx, [])
            if module.bias is not None:
                layer_biases.append(module.bias.data.clone())

    return model, args, weights_dict, bias_dict

        # var_inv_sum = 0
        # for bkey in range(self.m):
            # var_inv_sum +=  1/self.mms[bkey].weight[:,:,:,:].var()

        # self.subcnn.weight.data.fill_(0)
        # if self.bias:
            # self.subcnn.bias.data.fill_(0)
        # for bkey in range(self.m):
            # vrinv = (1/self.mms[bkey].weight[:,:,:,:].var())/var_inv_sum
            # self.subcnn.weight[:,:,:,:] += \
            # vrinv*self.mms[bkey].weight[:,:,:,:]
            # if self.bias:
                # self.subcnn.bias[:] += vrinv* self.mms[bkey].bias[:] 
def fuse_models(dir, versions=(0, 1, 2, 4)):
    """Fuse independently trained copies of one network into a single model.

    Each version number is substituted into the ``dir`` template
    (``dir.format(v)``) and loaded with ``get_model``. Every ``nn.Conv2d`` /
    ``nn.Linear`` layer of the first loaded model is then overwritten with an
    inverse-variance weighted average of that layer across all loaded models.
    Model ``i`` gets coefficient ``(1/var(W_i)) / sum_j (1/var(W_j))``, where
    ``var(W_i)`` is the variance over all entries of its weight tensor. Biases
    are combined with the same coefficients. All other parameters and buffers
    (e.g. BatchNorm) are left exactly as in the first model.

    Args:
        dir: path template containing a ``{0}`` placeholder for the version.
        versions: version numbers to fuse; the first one supplies the base
            model. Defaults to the previously hard-coded ``(0, 1, 2, 4)``.

    Returns:
        ``(fused_model, args, dir)``; ``args`` belongs to the last model loaded.

    Raises:
        ValueError: if ``versions`` is empty.
    """
    versions = tuple(versions)
    if not versions:
        raise ValueError("fuse_models needs at least one version to load")

    weights_dict = {}
    bias_dict = {}
    base_model = None
    args = None
    for position, version in enumerate(versions):
        model, args, weights_dict, bias_dict = get_model(
            dir.format(version), weights_dict, bias_dict)
        if position == 0:
            base_model = model

    for idx, module in enumerate(base_model.modules()):
        # Must match the layer selection in get_model so indices line up.
        if type(module) not in (nn.Conv2d, nn.Linear):
            continue

        weights = weights_dict[idx]
        inv_vars = [1 / w.var() for w in weights]
        inv_var_sum = sum(inv_vars)
        coeffs = [inv_var / inv_var_sum for inv_var in inv_vars]

        # weights[0] is a clone, so zeroing the live tensor does not affect it.
        module.weight.data.fill_(0)
        for coeff, weight in zip(coeffs, weights):
            module.weight.data += coeff * weight

        if module.bias is not None:
            module.bias.data.fill_(0)
            for coeff, bias in zip(coeffs, bias_dict[idx]):
                module.bias.data += coeff * bias

    return base_model, args, dir
        
    
def main(model, args, dir):
    """Evaluate ``model`` on the CIFAR-10 test split and print the result.

    Args:
        model: network to evaluate; assumed to already live on the GPU.
        args: run arguments; ``workers``, ``half`` and ``print_freq`` are read
            (the last one inside ``validate``). ``args.evaluate`` is set to True.
        dir: path template; printed with its ``{0}`` placeholder set to ``"X"``.

    Returns:
        ``(test_acc, test_loss)`` as produced by ``validate``.
    """
    args.evaluate = True
    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Only the test split is needed. The previous version also built an unused,
    # augmented training loader purely so its download=True would fetch the
    # data; request the download on the test split directly instead.
    val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]), download=True),
        batch_size=128, shuffle=False,
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda()

    # validate() casts inputs to half precision when args.half is set; the
    # model must match, otherwise the first layer fails on a dtype mismatch.
    if args.half:
        model.half()
        criterion.half()

    test_acc, test_loss = validate(val_loader, model, criterion, args)
    print(dir.format("X"),"Merge acc:",test_acc," |loss:",test_loss)
    return test_acc, test_loss




def validate(val_loader, model, criterion, args):
    """Run one gradient-free evaluation pass over ``val_loader``.

    Batches are moved to the GPU (inputs cast to half when ``args.half`` is
    set). A progress line is printed every ``args.print_freq`` batches and a
    top-1 summary at the end.

    Returns:
        ``(top1_avg, loss_avg)`` — sample-weighted mean top-1 precision (%)
        and sample-weighted mean loss.
    """
    timer = AverageMeter()
    loss_meter = AverageMeter()
    acc_meter = AverageMeter()

    # Switch the model to evaluation mode before scoring.
    model.eval()

    tick = time.time()
    with torch.no_grad():
        for step, (inputs, labels) in enumerate(val_loader):
            labels = labels.cuda()
            batch = inputs.cuda()
            if args.half:
                batch = batch.half()

            logits = model(batch)
            batch_loss = criterion(logits, labels)
            logits, batch_loss = logits.float(), batch_loss.float()

            # Record loss and precision, weighted by the batch size.
            n_samples = inputs.size(0)
            top1_batch = accuracy(logits.data, labels)[0]
            loss_meter.update(batch_loss.item(), n_samples)
            acc_meter.update(top1_batch.item(), n_samples)

            # Wall-clock time for this batch.
            timer.update(time.time() - tick)
            tick = time.time()

            if step % args.print_freq == 0:
                progress = ('Test: [{0}/{1}]\t'
                            'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                            'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                            'Prec@1 {top1.val:.3f} ({top1.avg:.3f})')
                print(progress.format(step, len(val_loader),
                                      batch_time=timer, loss=loss_meter,
                                      top1=acc_meter))

    summary = ' * Prec@1 {top1.avg:.8f}'
    print(summary.format(top1=acc_meter))

    return acc_meter.avg, loss_meter.avg



class AverageMeter(object):
    """Keep the latest value plus a count-weighted running mean.

    Attributes:
        val: most recent value passed to ``update``.
        sum: sum of ``val * n`` over all updates.
        count: total weight ``n`` seen so far.
        avg: ``sum / count`` (0 before the first update).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Forget all recorded values."""
        self.val = self.avg = self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` with weight ``n`` and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute precision@k (in percent) for each k in ``topk``.

    Args:
        output: score tensor whose dim 1 indexes classes (presumably
            ``(batch, num_classes)`` logits — confirm at call sites).
        target: ground-truth class indices, one per row of ``output``.
        topk: k values to evaluate; the largest must not exceed the number
            of classes.

    Returns:
        A list with one 0-dim tensor per k: the percentage of samples whose
        target is among the k highest-scoring classes.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    # After the transpose, row j holds every sample's j-th best class.
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` inherits the transposed (non-contiguous)
        # layout of `pred`, so view(-1) can fail for k > 1 on recent PyTorch.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def name_analyzer(dltitle, bargs):
    """Parse an experiment directory name into its components.

    ``bargs.dir`` is stripped from ``dltitle``; the remainder is split on
    ``"_"`` and read from the right. Example::

        bargs.dir = "./output/resnet_cifar_10/"
        "./output/resnet_cifar_10/save_resnet110_iea_nn_1_4_ "
            -> (4, "1", "iea_nn", "resnet110")

    The piece after the last ``"_"`` is discarded. Before it come Mense (int)
    and version (str), then the model title — one of the two-token titles
    ``drop_iea`` / ``iea_nn``, or a single token ``normal``, ``iea``,
    ``maxout``, ``base`` or ``drop`` — and finally the architecture.

    Args:
        dltitle: full directory/title string.
        bargs: object with a ``dir`` attribute holding the prefix to remove.

    Returns:
        ``(Mense, version, model_title, arch)``.

    Raises:
        ValueError: if no known model title is found.
    """
    two_token_titles = {"drop_iea", "iea_nn"}
    one_token_titles = {"normal", "iea", "maxout", "base", "drop"}

    ds = dltitle.replace(bargs.dir, "").split("_")
    ds.pop()  # whatever follows the final "_"
    Mense = int(ds.pop().replace(" ", ""))
    version = ds.pop().replace(" ", "")

    # Two-token titles are checked first so "drop_iea" is not read as "iea".
    pair = "_".join(token.replace(" ", "") for token in ds[-2:])
    last = ds[-1].replace(" ", "") if ds else ""
    if len(ds) >= 2 and pair in two_token_titles:
        model_title = pair
        del ds[-2:]
    elif last in one_token_titles:
        model_title = last
        ds.pop()
    else:
        # Previously this fell through to an UnboundLocalError on model_title.
        raise ValueError("unrecognised model type in {!r}".format(dltitle))

    arch = ds.pop()
    return Mense, version, model_title, arch

# Checkpoint path templates to fuse; fuse_models() fills "{0}" with each
# version number.
merged_tasks = [
    './output/resnet_cifar_10/save_resnet20_normal_{0}_4_',
    './output/resnet_cifar_10/save_resnet32_normal_{0}_4_',
    './output/resnet_cifar_10/save_resnet44_normal_{0}_4_',
    './output/resnet_cifar_10/save_resnet56_normal_{0}_4_',
    './output/resnet_cifar_10/save_resnet110_normal_{0}_4_',
]

# Guarded so importing this module does not load checkpoints or run evaluation.
if __name__ == "__main__":
    for task in merged_tasks:
        main(*fuse_models(task))
