import torch
import torchvision
import torchvision.transforms as transforms
import os
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import time




class AverageMeter(object):
    """Track the latest value and the running weighted mean of a metric.

    ``fmt`` is a format spec (e.g. ``':6.2f'``) applied to both the latest
    value and the running average when the meter is rendered with ``str()``.
    """
    def __init__(self, name, fmt=':f'):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Zero the latest value, the average, the running sum and the count."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` as the latest value, counting it ``n`` times in the mean."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        # Builds e.g. '{name} {val:6.2f} ({avg:6.2f})' and fills it from the attributes.
        template = '{name} {val%s} ({avg%s})' % (self.fmt, self.fmt)
        return template.format(**vars(self))


class ProgressMeter(object):
    """Print one tab-separated progress line: a batch counter followed by meters."""
    def __init__(self, num_batches, meters, prefix=""):
        self.batch_fmtstr = self._get_batch_fmtstr(num_batches)
        self.meters = meters
        self.prefix = prefix

    def display(self, batch):
        """Print ``prefix[batch/num_batches]`` followed by ``str()`` of every meter."""
        head = self.prefix + self.batch_fmtstr.format(batch)
        print('\t'.join([head, *map(str, self.meters)]))

    def _get_batch_fmtstr(self, num_batches):
        """Return a template such as ``'[{:3d}/100]'`` for ``num_batches=100``.

        The batch field is padded to the width of ``num_batches`` so that the
        counters line up across printed lines.
        """
        width = len(str(num_batches // 1))
        counter = '{:%dd}' % width
        return '[%s/%s]' % (counter, counter.format(num_batches))


def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy, in percent, for every k in ``topk``.

    ``output`` is ranked along dimension 1, and ``target.size(0)`` is taken as
    the number of samples. NOTE(review): this presumably expects ``output`` to
    be (batch, classes) scores and ``target`` to be (batch,) labels; confirm
    with callers. The result is a list holding one single-element tensor per
    entry of ``topk``, in the same order.
    """
    with torch.no_grad():
        n_samples = target.size(0)
        # Indices of the max(topk) best-scoring classes, best first, transposed
        # so that row r holds every sample's r-th best guess.
        _, ranked = output.topk(max(topk), 1, True, True)
        ranked = ranked.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))
        scale = 100.0 / n_samples
        return [hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(scale)
                for k in topk]




def validate(val_loader, model, criterion,device,initial = False):
    """Measure the top-1 / top-5 accuracy of ``model`` on ``val_loader``.

    Prints a one-line summary and returns the average top-1 accuracy.

    :param val_loader: iterable yielding ``(images, target)`` batches.
    :param model: network to evaluate. Its train/eval mode is changed in place.
    :param criterion: kept for backward compatibility. The loss was never
        reported, so it is no longer computed.
    :param device: device the images and targets are moved to.
    :param initial: if True, the model is put in ``train()`` mode instead of
        ``eval()`` mode. NOTE(review): presumably this makes normalization
        layers use batch statistics; confirm the intent with callers.
    :return: average top-1 accuracy in percent, as a CPU tensor. An empty
        loader returns ``tensor(0.)`` instead of raising.
    """
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')

    # Evaluate in eval mode unless the caller explicitly asked for train mode.
    if initial:
        model.train()
    else:
        model.eval()

    with torch.no_grad():
        for images, target in val_loader:
            images = images.to(device, non_blocking=True)
            target = target.to(device, non_blocking=True)
            output = model(images)

            # Weight each batch's accuracy by its size so the average is per-sample.
            acc1, acc5 = accuracy(output, target, topk=(1, 5))
            batch_size = images.size(0)
            top1.update(acc1[0], batch_size)
            top5.update(acc5[0], batch_size)

        print(' * Acc@1 {top1.avg:.3f} Acc@5 {top5.avg:.3f}'
              .format(top1=top1, top5=top5))

    avg = top1.avg
    if not torch.is_tensor(avg):
        # No batches were seen: the meter still holds its int reset value.
        avg = torch.tensor(float(avg))
    return avg.cpu().data


def loss_net(model,criterion,dataloader,device):
    """Return the average per-sample loss of ``model`` over ``dataloader``.

    Every batch's contribution is converted to a per-batch *sum* before
    accumulating, so batches of different sizes are weighted correctly. The
    conversion depends on the criterion's reduction mode:

    * ``'mean'``: the batch loss is multiplied by the batch size.
    * ``'sum'``: the batch loss is used as is.
    * ``'none'``: the per-sample losses are summed.

    If ``criterion`` has no ``reduction`` attribute (e.g. a plain function),
    it is assumed to average over the batch. NOTE(review): confirm this for
    custom criteria. With ``'mean'`` plus ``ignore_index``/class weights, the
    result is only an approximation of the per-sample mean.

    :param model: network to evaluate. Its train/eval mode is not changed.
    :param criterion: callable ``criterion(outputs, labels)``.
    :param dataloader: iterable whose items have inputs at ``[0]`` and labels
        at ``[1]``.
    :param device: device the inputs and labels are moved to.
    :return: 0-dim CPU tensor with the mean loss per sample.
    :raises ValueError: if the dataloader yields no samples.
    """
    reduction = getattr(criterion, 'reduction', 'mean')
    loss = 0.0
    total = 0
    with torch.no_grad():
        for data in dataloader:
            images, labels = data[0].to(device), data[1].to(device)
            outputs = model(images)
            batch_size = labels.size(0)
            batch_loss = criterion(outputs, labels)
            if reduction == 'mean':
                # Undo the per-batch averaging so that dividing by ``total``
                # below yields a per-sample mean.
                batch_loss = batch_loss * batch_size
            elif reduction == 'none':
                batch_loss = batch_loss.sum()
            loss += batch_loss
            total += batch_size
    if total == 0:
        raise ValueError('loss_net: dataloader yielded no samples')
    return loss.cpu().data / total

def acc_net(model,dataloader,device):
    """Return the top-1 accuracy of ``model`` over ``dataloader``, in percent.

    Each item of ``dataloader`` supplies inputs at ``[0]`` and labels at ``[1]``.
    The predicted class is the argmax of the model output along dimension 1.
    The model's train/eval mode is left unchanged. An empty dataloader raises
    ``ZeroDivisionError``.
    """
    n_seen = 0
    n_right = 0
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch[0].to(device)
            labels = batch[1].to(device)
            logits = model(inputs)
            predicted = logits.data.max(1)[1]
            n_seen += labels.size(0)
            n_right += (predicted == labels).sum().item()
    return 100 * n_right / n_seen


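# Per-channel (mean, std) normalization constants used in get_datalaoders, kept for reference:
#   CIFAR:    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
#   ImageNet: transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))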

def get_datalaoders(dataset_name, data_dir, batch_size=None):
    """Build the training and test dataloaders for a supported dataset.

    :param dataset_name: one of ``'Cifar10'``, ``'Cifar100'`` or ``'ImageNet'``.
    :param data_dir: dataset root directory. CIFAR is downloaded here if
        missing. For ImageNet it must contain ``train`` and ``val``
        ImageFolder directories.
    :param batch_size: batch size for both loaders. ``None`` (the default)
        keeps the historical defaults: 128 for CIFAR and 256 for ImageNet.
        An explicit value is now honoured for ImageNet too; previously it was
        silently ignored there.
    :return: ``(trainloader, testloader)``. Only the training loader shuffles.
    :raises ValueError: if ``dataset_name`` is not supported.
    """
    if dataset_name in ('Cifar10', 'Cifar100'):
        if batch_size is None:
            batch_size = 128
        # CIFAR per-channel (mean, std), shared by the train and test pipelines.
        normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                         (0.2023, 0.1994, 0.2010))
        transform_train = transforms.Compose([
            transforms.RandomHorizontalFlip(),
            transforms.RandomCrop(32, padding=4),
            transforms.ToTensor(),
            normalize,
        ])
        transform_test = transforms.Compose([
            transforms.ToTensor(),
            normalize,
        ])

        dataset_cls = (torchvision.datasets.CIFAR10 if dataset_name == 'Cifar10'
                       else torchvision.datasets.CIFAR100)
        trainset = dataset_cls(root=data_dir, train=True,
                               download=True, transform=transform_train)
        testset = dataset_cls(root=data_dir, train=False,
                              download=True, transform=transform_test)

        trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,
                                                  shuffle=True, num_workers=2)
        testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,
                                                 shuffle=False, num_workers=2)
        return trainloader, testloader

    if dataset_name == 'ImageNet':
        if batch_size is None:
            batch_size = 256
        traindir = os.path.join(data_dir, 'train')
        testdir = os.path.join(data_dir, 'val')
        normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                         std=[0.229, 0.224, 0.225])

        train_dataset = datasets.ImageFolder(
            traindir,
            transforms.Compose([
                transforms.RandomResizedCrop(224),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                normalize,
            ]))

        trainloader = torch.utils.data.DataLoader(
            train_dataset, batch_size=batch_size, shuffle=True,
            num_workers=64, pin_memory=True, sampler=None)

        testloader = torch.utils.data.DataLoader(
            datasets.ImageFolder(testdir, transforms.Compose([
                transforms.Resize(256),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                normalize,
            ])),
            batch_size=batch_size, shuffle=False,
            num_workers=4, pin_memory=True)
        return trainloader, testloader

    # ValueError subclasses Exception, so existing ``except Exception`` callers still work.
    raise ValueError('The dataset you entered does not exist: {!r}'.format(dataset_name))