import argparse
import os
import shutil
import time
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import pickle
import glob

    
def main(dir_, ven=False, cnt=3):
    """Evaluate the trained CIFAR-10 ResNet described by run directory ``dir_``.

    Loads the pickled training arguments from ``<dir_>/args.pkl``, builds the
    matching ResNet variant, restores ``model.th`` from ``args.save_dir`` and,
    except for ``maxout`` models, sets ``domms = False`` and calls
    ``apply_weights_pruning()`` on every module exposing ``domms``.

    Args:
        dir_: run directory containing ``args.pkl``.
        ven: if False, run ``validate`` (which also caches this model's outputs
            for later ensembling); if True, run ``validate_ensemble`` over the
            outputs cached by earlier calls.
        cnt: number of cached model outputs to average when ``ven`` is True.

    Side effects:
        Sets the module globals ``args`` and ``best_prec1`` and re-writes
        ``<args.save_dir>/args.pkl`` with ``evaluate``/``resume`` updated.

    Raises:
        ValueError: if ``args.model_type`` is not one of the known variants.
    """
    global args, best_prec1
    # NOTE(review): pickle.load can execute arbitrary code; only use this on
    # trusted run directories.
    with open(dir_ + "/args.pkl", "rb") as input_file:
        args = pickle.load(input_file)

    # Persisted into args.pkl below; this script only ever evaluates.
    args.evaluate = True

    # Import the ResNet variant the run was trained with.
    if args.model_type == "normal":
        import resnet as resnet
    elif args.model_type == "iea":
        import resnet_iea as resnet
    elif args.model_type == "maxout":
        import resnet_maxout as resnet
    elif args.model_type == "base":
        import resnet_base as resnet
    elif args.model_type == "drop":
        import resnet_drop as resnet
    elif args.model_type == "drop_iea":
        import resnet_drop_iea as resnet
    elif args.model_type == "iea_nn":
        import resnet_iea_nn as resnet
    else:
        # Fail before touching the file system instead of crashing later on an
        # unbound `resnet`.
        raise ValueError("Unknown model_type {!r}".format(args.model_type))

    print(args.save_dir)
    # NOTE(review): the checkpoint is read from args.save_dir (as pickled at
    # training time), not from dir_; these may differ if runs were moved.
    args.resume = os.path.join(args.save_dir, 'model.th')

    os.makedirs(args.save_dir, exist_ok=True)
    with open(args.save_dir + '/args.pkl', 'wb') as handle:
        pickle.dump(args, handle, protocol=pickle.HIGHEST_PROTOCOL)

    # "normal" and "drop" constructors take no `m` (Mense) argument.
    if args.model_type in ("normal", "drop"):
        model = torch.nn.DataParallel(resnet.__dict__[args.arch]())
    else:
        model = torch.nn.DataParallel(resnet.__dict__[args.arch](m=args.Mense))
    model.cuda()

    # Reset on every call so a missing checkpoint never raises NameError or
    # reports the previously evaluated model's score.
    best_prec1 = None
    if os.path.isfile(args.resume):
        print("=> loading checkpoint '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = 200  # hard-coded; checkpoint['epoch'] is ignored
        best_prec1 = checkpoint['best_prec1']
        model.load_state_dict(checkpoint['state_dict'])
    else:
        print("=> no checkpoint found at '{}'".format(args.resume))

    if args.model_type != "maxout":
        # NOTE(review): the meaning of `domms` / apply_weights_pruning lives in
        # the resnet_* modules and is not visible here.
        pruned = 0
        for m in model.modules():
            if hasattr(m, "domms"):
                m.domms = False
                m.apply_weights_pruning()
                pruned += 1
        print("CNT:", pruned)

    cudnn.benchmark = True

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Only the test split is needed. download=True keeps fresh checkouts
    # working; previously only the (unused) training loader requested it.
    val_loader = torch.utils.data.DataLoader(
        datasets.CIFAR10(root='./data', train=False, transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ]), download=True),
        batch_size=128, shuffle=False,  # shuffle=False keeps batch indices aligned for ensembling
        num_workers=args.workers, pin_memory=True)

    # define loss function (criterion)
    criterion = nn.CrossEntropyLoss().cuda()

    if not ven:
        test_acc = validate(val_loader, model, criterion)
        print("Before:", best_prec1, "| After:", test_acc)
    else:
        test_acc = validate_ensemble(val_loader, model, criterion, cnt)
        print("Ensemble:", test_acc, "using ", cnt, " ensembles")


# Per-batch cache of model outputs keyed by validation batch index:
# validate() appends one CPU tensor per evaluated model, and
# validate_ensemble() averages the first `cnt` entries of each list.
results_save_ensemble = dict()

def validate(val_loader, model, criterion):
    """
    Run evaluation of a single model and cache its outputs for ensembling.

    For every batch ``i`` of ``val_loader`` the float32 model output is copied
    to the CPU and appended to the module-level ``results_save_ensemble[i]``
    list. Successive calls (one per model) therefore accumulate the per-batch
    outputs that ``validate_ensemble`` averages. Batch ``i`` must cover the
    same samples on every call, as with ``main``'s ``shuffle=False`` loader.

    Reads the global ``args`` (``half``, ``print_freq``).

    Returns:
        Top-1 precision (percent), averaged over batches weighted by batch size.
    """
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode
    model.eval()

    end = time.time()
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            target = target.cuda()
            input_var = images.cuda()
            if args.half:
                input_var = input_var.half()

            # compute output
            output = model(input_var)
            loss = criterion(output, target).float()

            output = output.float()
            # One cached entry per evaluated model, in call order.
            results_save_ensemble.setdefault(i, []).append(output.cpu())

            # measure accuracy and record loss
            prec1 = accuracy(output, target)[0]
            losses.update(loss.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1))

    print(' * Prec@1 {top1.avg:.3f}'
          .format(top1=top1))

    return top1.avg


def validate_ensemble(val_loader, model, criterion, cnt):
    """
    Evaluate the average of ``cnt`` cached model outputs.

    ``model`` is not run. For each batch ``i`` of ``val_loader`` the first
    ``cnt`` tensors in ``results_save_ensemble[i]`` (filled by earlier
    ``validate`` calls) are averaged with equal weight ``1/cnt`` and scored
    against the batch targets.

    Reads the global ``args`` (``print_freq``).

    Args:
        cnt: number of cached model outputs to average per batch.

    Returns:
        Top-1 precision (percent) of the averaged outputs.

    Raises:
        ValueError: if ``cnt < 1`` or fewer than ``cnt`` outputs are cached
            for some batch.
    """
    if cnt < 1:
        raise ValueError("cnt must be >= 1, got {}".format(cnt))

    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()

    # switch to evaluate mode (kept for parity with validate; model is unused)
    model.eval()

    weight = 1 / cnt
    end = time.time()
    with torch.no_grad():
        for i, (images, target) in enumerate(val_loader):
            cached = results_save_ensemble.get(i, [])
            if len(cached) < cnt:
                raise ValueError(
                    "batch {}: only {} cached outputs, cannot average {}".format(
                        i, len(cached), cnt))

            # Same accumulation order as a plain sum of weighted members; the
            # first product is a fresh tensor, so the cache is never mutated.
            output = weight * cached[0]
            for member in cached[1:cnt]:
                output += weight * member

            loss = criterion(output.cuda(), target.cuda()).float()

            # measure accuracy and record loss
            prec1 = accuracy(output, target)[0]
            losses.update(loss.item(), images.size(0))
            top1.update(prec1.item(), images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % args.print_freq == 0:
                print('Test: [{0}/{1}]\t'
                      'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                      'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                      'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                          i, len(val_loader), batch_time=batch_time, loss=losses,
                          top1=top1))

    print(' * Prec@1 {top1.avg:.3f}'
          .format(top1=top1))

    return top1.avg


class AverageMeter:
    """Track the most recent value and the count-weighted mean of all values seen."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero the latest value, the mean, the weighted sum and the count."""
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        """Record ``val`` as observed ``n`` times and refresh the running mean."""
        self.val = val
        self.count += n
        self.sum += val * n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute precision@k for each k in ``topk``.

    Precision@k is the percentage of samples whose target label is among the
    k highest-scoring predictions.

    Args:
        output: score tensor whose dim 0 is the batch; top-k is taken along dim 1.
        target: label tensor with one entry per row of ``output``.
        topk: iterable of k values.

    Returns:
        List with one 0-dim tensor per k, each a percentage in [0, 100].
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row j holds each sample's j-th best class
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` inherits the transposed (non-contiguous)
        # layout of `pred`, so view(-1) raises RuntimeError for k > 1.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


def name_analyzer(dltitle, bargs):
    """Parse a run directory name into ``(Mense, version, model_title, arch)``.

    After removing ``bargs.dir``, the name is split on ``_`` and read from the
    right as ``..._<arch>_<model_title>_<version>_<Mense>_<tail>``. Example::

        ./output/resnet_cifar_10/save_resnet110_iea_nn_1_4_/
            -> (4, "1", "iea_nn", "resnet110")

    Spaces inside the ``_``-separated fields are ignored.

    Args:
        dltitle: path of the run directory (e.g. as returned by ``glob(... + "*/")``).
        bargs: namespace whose ``dir`` attribute is the prefix to strip.

    Returns:
        ``Mense`` as int; ``version``, ``model_title`` and ``arch`` as str.

    Raises:
        ValueError: if the model field is not a recognised model type, or
            ``Mense`` is not an integer.
    """
    two_token_models = ("drop_iea", "iea_nn")
    one_token_models = ("normal", "iea", "maxout", "base", "drop")

    relative = dltitle.replace(bargs.dir, "")
    fields = [part.replace(" ", "") for part in relative.split("_")]
    fields.pop()  # whatever follows the last "_" (e.g. the "/" from "*/")
    Mense = int(fields.pop())
    version = fields.pop()

    # Two-part names are tried first: "drop_iea" ends in "iea" and would
    # otherwise be misread as the single-part "iea".
    if len(fields) >= 2 and "_".join(fields[-2:]) in two_token_models:
        model_title = "_".join(fields[-2:])
        del fields[-2:]
    elif fields and fields[-1] in one_token_models:
        model_title = fields.pop()
    else:
        raise ValueError(
            "Cannot determine model type from run directory {!r}".format(dltitle))

    arch = fields[-1]
    return Mense, version, model_title, arch
parser = argparse.ArgumentParser()
parser.add_argument('--v', default='10',
                    help='suffix of the output folder to scan: ./output/resnet_cifar_<v>/')
parser.add_argument('--dir', default='N',
                    help='ignored: the folder to scan is always derived from --v')
bargs = parser.parse_args()
# --dir is always overwritten from --v.
bargs.dir = './output/resnet_cifar_{0}/'.format(bargs.v)
# glob's order is arbitrary; sorting makes the evaluation order reproducible.
dirs = sorted(glob.glob(bargs.dir + "*/"))

# Filters applied to the fields parsed from each run directory name.
version_accept = [0, 1, 2, 4]  # alternatives: [0, 1, 2, 3, 4]
M_accept = [4]  # alternatives: [2, 4, 8, 16]
model_accept = ["iea_nn"]  # alternatives: ["iea", "iea_nn", "maxout", "normal"]
arch_accept = ["resnet56"]  # alternatives: ["resnet56", "resnet110"]
# Minimum number of files for a run dir without epoch.txt to count as finished.
chk_lst = 4
# Evaluate every finished run that passes the filters, then ensemble them.
cnt = 0
last_d = None
for d in dirs:
    fles = glob.glob(d + "/*")
    Mense, version, model_title, arch = name_analyzer(d, bargs)
    if (int(version) not in version_accept
            or int(Mense) not in M_accept
            or model_title not in model_accept
            or arch not in arch_accept):
        continue

    epoch_file = os.path.join(d, "epoch.txt")
    if os.path.exists(epoch_file):  # New case
        with open(epoch_file, 'r') as handle:
            epo = int(handle.readline())
        # presumably 199 is the last epoch index of a 200-epoch schedule — confirm
        finished = epo == 199
    else:  # old cases: judged by the number of files written
        finished = len(fles) >= chk_lst

    if finished:
        main(d)  # also caches this model's outputs for the ensemble
        last_d = d
        cnt += 1

# The ensemble pass needs at least one evaluated run (cnt >= 1).
if last_d is None:
    print("No finished runs matched the filters; skipping ensemble evaluation.")
else:
    main(last_d, True, cnt)

    