import os
import sys
import re
import datetime

import numpy
import math

import torch
from torch.optim.lr_scheduler import _LRScheduler
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
# from dataset import CIFAR100Train, CIFAR100Test
# from cifar_dataset import CIFAR100
from dataset_cifar import CIFAR100, CIFAR100_META
from dataset_breeds import BREEDS, BREEDS_META
from dataset_tiered_imagenet import TieredImageNet, TieredImageNet_Meta
import random
from PIL import ImageFilter


def get_network(args):
    """Build the network named by ``args.net`` and optionally move it to the GPU.

    The model's module is imported lazily, so only the selected architecture's
    code is loaded.

    Args:
        args: namespace with at least ``net`` (architecture name such as
            ``'resnet18'``) and ``gpu`` (truthy to call ``.cuda()`` on the model).

    Returns:
        The constructed network, moved to the GPU when ``args.gpu`` is truthy.

    Exits:
        Prints a message and terminates the process with exit status 1 when
        ``args.net`` is not a supported architecture name.
    """
    import importlib

    # architecture name -> (module path, factory function name)
    registry = {
        'vgg11': ('models.vgg', 'vgg11_bn'),
        'vgg13': ('models.vgg', 'vgg13_bn'),
        'vgg16': ('models.vgg', 'vgg16_bn'),
        'vgg19': ('models.vgg', 'vgg19_bn'),
        'densenet121': ('models.densenet', 'densenet121'),
        'densenet161': ('models.densenet', 'densenet161'),
        'densenet169': ('models.densenet', 'densenet169'),
        'densenet201': ('models.densenet', 'densenet201'),
        'googlenet': ('models.googlenet', 'googlenet'),
        'inceptionv3': ('models.inceptionv3', 'inceptionv3'),
        'inceptionv4': ('models.inceptionv4', 'inceptionv4'),
        'inceptionresnetv2': ('models.inceptionv4', 'inception_resnet_v2'),
        'xception': ('models.xception', 'xception'),
        'resnet18': ('models.resnet', 'resnet18'),
        'resnet34': ('models.resnet', 'resnet34'),
        'resnet50': ('models.resnet', 'resnet50'),
        'resnet101': ('models.resnet', 'resnet101'),
        'resnet152': ('models.resnet', 'resnet152'),
        'preactresnet18': ('models.preactresnet', 'preactresnet18'),
        'preactresnet34': ('models.preactresnet', 'preactresnet34'),
        'preactresnet50': ('models.preactresnet', 'preactresnet50'),
        'preactresnet101': ('models.preactresnet', 'preactresnet101'),
        'preactresnet152': ('models.preactresnet', 'preactresnet152'),
        'resnext50': ('models.resnext', 'resnext50'),
        'resnext101': ('models.resnext', 'resnext101'),
        'resnext152': ('models.resnext', 'resnext152'),
        'shufflenet': ('models.shufflenet', 'shufflenet'),
        'shufflenetv2': ('models.shufflenetv2', 'shufflenetv2'),
        'squeezenet': ('models.squeezenet', 'squeezenet'),
        'mobilenet': ('models.mobilenet', 'mobilenet'),
        'mobilenetv2': ('models.mobilenetv2', 'mobilenetv2'),
        'nasnet': ('models.nasnet', 'nasnet'),
        'attention56': ('models.attention', 'attention56'),
        'attention92': ('models.attention', 'attention92'),
        'seresnet18': ('models.senet', 'seresnet18'),
        'seresnet34': ('models.senet', 'seresnet34'),
        'seresnet50': ('models.senet', 'seresnet50'),
        'seresnet101': ('models.senet', 'seresnet101'),
        'seresnet152': ('models.senet', 'seresnet152'),
        'wideresnet': ('models.wideresidual', 'wideresnet'),
        'stochasticdepth18': ('models.stochasticdepth', 'stochastic_depth_resnet18'),
        'stochasticdepth34': ('models.stochasticdepth', 'stochastic_depth_resnet34'),
        'stochasticdepth50': ('models.stochasticdepth', 'stochastic_depth_resnet50'),
        'stochasticdepth101': ('models.stochasticdepth', 'stochastic_depth_resnet101'),
    }

    entry = registry.get(args.net)
    if entry is None:
        print('the network name you have entered is not supported yet')
        # Non-zero status so shells / job schedulers see the failure;
        # a bare sys.exit() reports success (status 0).
        sys.exit(1)

    module_name, factory_name = entry
    factory = getattr(importlib.import_module(module_name), factory_name)
    net = factory()

    if args.gpu:  # use_gpu
        net = net.cuda()

    return net


def get_training_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=True, coarse=True):
    """Return a DataLoader over the CIFAR-100 training split, with augmentation.

    Augmentation: random 32x32 crop with 4px padding, random horizontal flip,
    random rotation of up to 15 degrees, then tensor conversion and
    normalization.

    Args:
        mean: per-channel mean passed to ``transforms.Normalize``.
        std: per-channel std passed to ``transforms.Normalize``.
        batch_size: DataLoader batch size.
        num_workers: number of DataLoader worker processes.
        shuffle: whether the DataLoader shuffles samples.
        coarse: forwarded unchanged to the project ``CIFAR100`` dataset
            (presumably selects the coarse superclass labels instead of the
            fine labels -- TODO confirm against dataset_cifar).

    Returns:
        torch.utils.data.DataLoader over ``CIFAR100(root='./data', train=True)``.
    """
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(15),
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    cifar100_training = CIFAR100(root='./data', train=True, download=True, transform=transform_train, coarse=coarse)
    cifar100_training_loader = DataLoader(
        cifar100_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    return cifar100_training_loader

def get_test_dataloader(mean, std, batch_size=16, num_workers=2, shuffle=True, coarse=True):
    """Return a DataLoader over the CIFAR-100 test split (no augmentation).

    Images are only converted to tensors and normalized.

    Args:
        mean: per-channel mean passed to ``transforms.Normalize``.
        std: per-channel std passed to ``transforms.Normalize``.
        batch_size: DataLoader batch size.
        num_workers: number of DataLoader worker processes.
        shuffle: whether the DataLoader shuffles samples (defaults to True,
            even for evaluation).
        coarse: forwarded unchanged to the project ``CIFAR100`` dataset
            (presumably selects coarse superclass labels -- TODO confirm).

    Returns:
        torch.utils.data.DataLoader over ``CIFAR100(root='./data', train=False)``.
    """
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std)
    ])
    cifar100_test = CIFAR100(root='./data', train=False, download=True, transform=transform_test, coarse=coarse)
    cifar100_test_loader = DataLoader(
        cifar100_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    return cifar100_test_loader


# def get_training_dataloader_cifar100(batch_size=16, num_workers=2, shuffle=True, twocrops=False):
#     transform_train = transforms.Compose([
#         # transforms.ToPILImage(),
#         transforms.RandomResizedCrop(32, scale=(0.2, 1.)),
#         transforms.RandomHorizontalFlip(),
#         transforms.ToTensor(),
#         transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761])
#     ])
#
#     if twocrops is True:
#         transform_train = TwoCropsTransform(transform_train)
#
#     cifar100_training = CIFAR100(root='./data',
#                                  train=True,
#                                  download=True,
#                                  transform=transform_train,
#                                  val=False,
#                                  seed=1000,
#                                  tr_ratio=0.9)
#     cifar100_training_loader = DataLoader(cifar100_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
#
#     return cifar100_training_loader


# def get_validation_dataloader_cifar100(batch_size=16, num_workers=2, shuffle=True):
#     transform_validation = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761])
#     ])
#
#     cifar100_validation = CIFAR100(root='./data',
#                                    train=True,
#                                    download=True,
#                                    transform=transform_validation,
#                                    val=True,
#                                    seed=1000,
#                                    tr_ratio=0.9)
#     cifar100_validation_loader = DataLoader(cifar100_validation, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
#
#     return cifar100_validation_loader


# def get_test_dataloader_cifar100(batch_size=16, num_workers=2, shuffle=True):
#     transform_test = transforms.Compose([
#         transforms.ToTensor(),
#         transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761])
#     ])
#
#     cifar100_test = CIFAR100(root='./data',
#                              train=False,
#                              download=True,
#                              transform=transform_test,
#                              val=False,
#                              seed=1000,
#                              tr_ratio=0.9)
#     cifar100_test_loader = DataLoader(cifar100_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)
#
#     return cifar100_test_loader


def get_training_dataloader_cifar100(batch_size=16, num_workers=2, shuffle=True, twocrops=False):
    """Build the CIFAR-100 training loader over the training share of the project's split.

    Augmentation: random resized crop to 32x32 (scale 0.2-1.0) and a random
    horizontal flip, followed by tensor conversion and normalization with
    fixed per-channel mean/std constants.

    Args:
        batch_size: samples per batch.
        num_workers: DataLoader worker processes.
        shuffle: whether to shuffle every epoch.
        twocrops: when exactly ``True``, every sample is augmented twice and
            returned as a ``[q, k]`` list via ``TwoCropsTransform``.

    Returns:
        DataLoader that drops the last incomplete batch.
    """
    normalize = transforms.Normalize([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761])
    augment = transforms.Compose([
        transforms.RandomResizedCrop(32, scale=(0.2, 1.)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    if twocrops is True:
        augment = TwoCropsTransform(augment)

    # val=False with seed/tr_ratio: presumably the 90% training share of a
    # seeded train/val split -- TODO confirm in dataset_cifar.
    dataset = CIFAR100(
        root='./data',
        train=True,
        transform=augment,
        target_transform=None,
        download=True,
        val=False,
        seed=1000,
        tr_ratio=0.9,
    )
    return DataLoader(dataset, shuffle=shuffle, num_workers=num_workers,
                      batch_size=batch_size, drop_last=True)


def get_validation_dataloader_cifar100(batch_size=16, num_workers=2, shuffle=True):
    """Build the CIFAR-100 validation loader over the held-out share of the training data.

    No augmentation is applied. Images are only converted to tensors and
    normalized with fixed per-channel mean/std constants. The final partial
    batch is kept.
    """
    stats = ([0.5071, 0.4867, 0.4408], [0.2675, 0.2565, 0.2761])
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(*stats),
    ])

    # train=True + val=True: presumably the held-out 10% (tr_ratio=0.9,
    # seed=1000) of the training data -- TODO confirm in dataset_cifar.
    split = CIFAR100(
        root='./data',
        train=True,
        transform=preprocess,
        target_transform=None,
        download=True,
        val=True,
        seed=1000,
        tr_ratio=0.9,
    )
    return DataLoader(split, shuffle=shuffle, num_workers=num_workers,
                      batch_size=batch_size, drop_last=False)


def get_test_dataloader_cifar100(n_test_runs, n_ways, n_shots, n_queries, n_aug_support_samples, fg=False, batch_size=1, num_workers=0):
    """Build the episodic (few-shot) CIFAR-100 meta-test loader.

    Args:
        n_test_runs: number of episodes, forwarded to ``CIFAR100_META``.
        n_ways: classes per episode.
        n_shots: labelled examples per class.
        n_queries: query examples per class.
        n_aug_support_samples: forwarded to ``CIFAR100_META`` (presumably the
            number of augmented copies per support image -- TODO confirm).
        fg: forwarded unchanged to ``CIFAR100_META``.
        batch_size: episodes per batch (default 1).
        num_workers: DataLoader worker processes.

    Returns:
        Non-shuffled DataLoader that keeps the last partial batch. The dataset
        is built with ``fix_seed=True``.
    """
    norm_mean = [0.5071, 0.4867, 0.4408]
    norm_std = [0.2675, 0.2565, 0.2761]

    # NOTE(review): RandomCrop(32) without padding is the identity on a 32x32
    # image; if the inputs are 32x32 CIFAR images, padding (e.g. padding=4)
    # may have been intended -- confirm before changing.
    episode_train_tf = transforms.Compose([
        transforms.ToPILImage(),
        transforms.RandomCrop(32),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])
    episode_test_tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(norm_mean, norm_std),
    ])

    episodes = CIFAR100_META(
        data_root='./data',
        n_test_runs=n_test_runs,
        n_ways=n_ways,
        n_shots=n_shots,
        n_queries=n_queries,
        n_aug_support_samples=n_aug_support_samples,
        meta_train=False,
        meta_val=False,
        transform_train=episode_train_tf,
        transform_test=episode_test_tf,
        fg=fg,
        fix_seed=True,
    )
    return DataLoader(episodes, shuffle=False, num_workers=num_workers,
                      batch_size=batch_size, drop_last=False)


def get_training_dataloader_breeds(ds_name, info_dir, data_dir, batch_size=16, num_workers=2, shuffle=True, twocrops=False):
    """Build the BREEDS training loader.

    Augmentation order:
      1. random resized crop to 224 (scale 0.2-1.0);
      2. colour jitter (p=0.8);
      3. random grayscale (p=0.2);
      4. Gaussian blur (p=0.5);
      5. horizontal flip;
      6. tensor conversion and normalization with fixed per-channel mean/std.

    Args:
        ds_name, info_dir, data_dir: forwarded unchanged to ``BREEDS``.
        batch_size: samples per batch.
        num_workers: DataLoader worker processes.
        shuffle: whether to shuffle every epoch.
        twocrops: when exactly ``True``, every sample is augmented twice and
            returned as a ``[q, k]`` list via ``TwoCropsTransform``.

    Returns:
        DataLoader that keeps the last partial batch.
    """
    colour_jitter = transforms.RandomApply([
        transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
    ], p=0.8)
    blur = transforms.RandomApply([GaussianBlur([.1, 2.])], p=0.5)
    pipeline = transforms.Compose([
        transforms.RandomResizedCrop(224, scale=(0.2, 1.)),
        colour_jitter,
        transforms.RandomGrayscale(p=0.2),
        blur,
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4717, 0.4499, 0.3837], std=[0.2600, 0.2516, 0.2575]),
    ])
    if twocrops is True:
        pipeline = TwoCropsTransform(pipeline)

    # partition='train' with train=True, seed, tr_ratio: presumably the 90%
    # training share of a seeded split -- TODO confirm in dataset_breeds.
    dataset = BREEDS(
        info_dir=info_dir,
        data_dir=data_dir,
        ds_name=ds_name,
        partition='train',
        split=None,
        transform=pipeline,
        train=True,
        seed=1000,
        tr_ratio=0.9,
    )
    return DataLoader(dataset, shuffle=shuffle, num_workers=num_workers,
                      batch_size=batch_size, drop_last=False)


def get_validation_dataloader_breeds(ds_name, info_dir, data_dir, batch_size=16, num_workers=2, shuffle=True):
    """Build the BREEDS validation loader.

    Deterministic preprocessing: resize to 256, center-crop to 224, then
    tensor conversion and normalization with fixed per-channel mean/std.
    The final partial batch is kept.
    """
    eval_steps = [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4717, 0.4499, 0.3837], std=[0.2600, 0.2516, 0.2575]),
    ]
    preprocess = transforms.Compose(eval_steps)

    # partition='train' with train=False: presumably the held-out 10% of a
    # seeded split of the training partition -- TODO confirm in dataset_breeds.
    held_out = BREEDS(
        info_dir=info_dir,
        data_dir=data_dir,
        ds_name=ds_name,
        partition='train',
        split=None,
        transform=preprocess,
        train=False,
        seed=1000,
        tr_ratio=0.9,
    )
    return DataLoader(held_out, shuffle=shuffle, num_workers=num_workers,
                      batch_size=batch_size, drop_last=False)


# def get_test_dataloader_breeds(ds_name, info_dir, data_dir, batch_size=16, num_workers=2, shuffle=True):
#     transform_test = transforms.Compose([
#         transforms.Resize(256),
#         transforms.CenterCrop(224),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.4717, 0.4499, 0.3837], std=[0.2600, 0.2516, 0.2575])
#     ])
#
#     # transform_test = transforms.Compose([
#     #     transforms.Resize(256),
#     #     transforms.RandomCrop(224),
#     #     transforms.RandomHorizontalFlip(),
#     #     transforms.ToTensor(),
#     #     transforms.Normalize(mean=[0.4717, 0.4499, 0.3837], std=[0.2600, 0.2516, 0.2575])
#     # ])
#
#     breeds_test = BREEDS(info_dir=info_dir,
#                          data_dir=data_dir,
#                          ds_name=ds_name,
#                          partition='val',
#                          split=None,
#                          transform=transform_test)
#     breeds_test_loader = DataLoader(breeds_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size, drop_last=False)
#
#     return breeds_test_loader


def get_test_dataloader_breeds(ds_name, info_dir, data_dir, n_test_runs, n_ways, n_shots, n_queries, n_aug_support_samples, fg=False, batch_size=1, num_workers=0):
    """Build the episodic (few-shot) BREEDS meta-test loader.

    Two pipelines are handed to ``BREEDS_META``:

    - ``transform_train``: resize 256, random crop 224, random horizontal flip.
    - ``transform_test``: resize 256, center crop 224.

    Both end with tensor conversion and normalization.

    Args:
        ds_name, info_dir, data_dir: forwarded unchanged to ``BREEDS_META``.
        n_test_runs, n_ways, n_shots, n_queries, n_aug_support_samples, fg:
            episode configuration, forwarded unchanged to ``BREEDS_META``.
        batch_size: episodes per batch (default 1).
        num_workers: DataLoader worker processes.

    Returns:
        Non-shuffled DataLoader that keeps the last partial batch. The dataset
        is built with ``fix_seed=True``.
    """
    def _finish():
        # Shared tail of both pipelines.
        return [
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.4717, 0.4499, 0.3837], std=[0.2600, 0.2516, 0.2575]),
        ]

    episode_train_tf = transforms.Compose(
        [transforms.Resize(256), transforms.RandomCrop(224), transforms.RandomHorizontalFlip()] + _finish()
    )
    episode_test_tf = transforms.Compose(
        [transforms.Resize(256), transforms.CenterCrop(224)] + _finish()
    )

    episodes = BREEDS_META(
        info_dir=info_dir,
        data_dir=data_dir,
        ds_name=ds_name,
        split=None,
        transform_train=episode_train_tf,
        transform_test=episode_test_tf,
        n_test_runs=n_test_runs,
        n_ways=n_ways,
        n_shots=n_shots,
        n_queries=n_queries,
        n_aug_support_samples=n_aug_support_samples,
        fg=fg,
        fix_seed=True,
    )
    return DataLoader(episodes, shuffle=False, num_workers=num_workers,
                      batch_size=batch_size, drop_last=False)


# def get_test_dataloader_breeds(ds_name, info_dir, data_dir, batch_size=16, num_workers=2, shuffle=True):
#     transform_test = transforms.Compose([
#         transforms.Resize(256),
#         transforms.CenterCrop(224),
#         transforms.ToTensor(),
#         transforms.Normalize(mean=[0.4717, 0.4499, 0.3837], std=[0.2600, 0.2516, 0.2575])
#     ])
#
#     breeds_test = BREEDS(info_dir=info_dir,
#                          data_dir=data_dir,
#                          ds_name=ds_name,
#                          partition='test',
#                          split=None,
#                          transform=transform_test)
#     breeds_test_loader = DataLoader(breeds_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size, drop_last=False)
#
#     return breeds_test_loader


def get_training_dataloader_tieredimagenet(batch_size=16, num_workers=2, shuffle=True, twocrops=False,
                                           root='/nfs/data/usr/jni/datasets/tiered_imagenet'):
    """Return a DataLoader over the tieredImageNet ``'train'`` partition.

    Augmentation:
      1. random resized crop to 84 (scale 0.2-1.0);
      2. colour jitter (p=0.8);
      3. random grayscale (p=0.2);
      4. horizontal flip;
      5. tensor conversion and normalization with fixed per-channel mean/std.

    Args:
        batch_size: samples per batch.
        num_workers: DataLoader worker processes.
        shuffle: whether to shuffle every epoch.
        twocrops: when exactly ``True``, every sample is augmented twice and
            returned as a ``[q, k]`` list via ``TwoCropsTransform``.
        root: dataset directory passed to ``TieredImageNet``. The default is
            the previously hard-coded cluster path; override it on other
            machines.

    Returns:
        DataLoader that drops the last incomplete batch.
    """
    transform_train = transforms.Compose([
        transforms.RandomResizedCrop(84, scale=(0.2, 1.)),
        transforms.RandomApply([
            transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
        ], p=0.8),
        transforms.RandomGrayscale(p=0.2),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4783, 0.4564, 0.4102], std=[0.2794756041544754, 0.27362885638931284, 0.28587171211901724])
    ])

    if twocrops is True:
        transform_train = TwoCropsTransform(transform_train)

    tieredimagenet_training = TieredImageNet(root=root,
                                             partition='train',
                                             transform=transform_train,
                                             target_transform=None)

    tieredimagenet_training_loader = DataLoader(tieredimagenet_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size, drop_last=True)

    return tieredimagenet_training_loader


def get_validation_dataloader_tieredimagenet(n_test_runs, n_ways, n_shots, n_queries, n_aug_support_samples, fg=False, batch_size=1, num_workers=0,
                                             root='/nfs/data/usr/jni/datasets/tiered_imagenet'):
    """Return the episodic (few-shot) tieredImageNet ``'validation'`` loader.

    Two pipelines are handed to ``TieredImageNet_Meta``:

    - ``transform_train``: random 84x84 crop with 8px padding and a random
      horizontal flip.
    - ``transform_test``: no spatial augmentation.

    Both end with tensor conversion and normalization.

    Args:
        n_test_runs, n_ways, n_shots, n_queries, n_aug_support_samples, fg:
            episode configuration, forwarded unchanged to ``TieredImageNet_Meta``.
        batch_size: episodes per batch (default 1).
        num_workers: DataLoader worker processes.
        root: dataset directory passed as ``data_root``. The default is the
            previously hard-coded cluster path; override it on other machines.

    Returns:
        Non-shuffled DataLoader that keeps the last partial batch. The dataset
        is built with ``fix_seed=True``.
    """
    transform_meta_train = transforms.Compose([
        transforms.RandomCrop(84, padding=8),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4783, 0.4564, 0.4102], std=[0.2794756041544754, 0.27362885638931284, 0.28587171211901724])
    ])

    transform_meta_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4783, 0.4564, 0.4102], std=[0.2794756041544754, 0.27362885638931284, 0.28587171211901724])
    ])

    tieredimagenet_validation = TieredImageNet_Meta(data_root=root,
                                                    n_test_runs=n_test_runs,
                                                    n_ways=n_ways,
                                                    n_shots=n_shots,
                                                    n_queries=n_queries,
                                                    n_aug_support_samples=n_aug_support_samples,
                                                    partition='validation',
                                                    transform_train=transform_meta_train,
                                                    transform_test=transform_meta_test,
                                                    fg=fg,
                                                    fix_seed=True)

    tieredimagenet_validation_loader = DataLoader(tieredimagenet_validation, shuffle=False, num_workers=num_workers, batch_size=batch_size, drop_last=False)

    return tieredimagenet_validation_loader


def get_test_dataloader_tieredimagenet(n_test_runs, n_ways, n_shots, n_queries, n_aug_support_samples, fg=False, batch_size=1, num_workers=0,
                                       root='/nfs/data/usr/jni/datasets/tiered_imagenet'):
    """Return the episodic (few-shot) tieredImageNet ``'test'`` loader.

    Two pipelines are handed to ``TieredImageNet_Meta``:

    - ``transform_train``: random 84x84 crop with 8px padding and a random
      horizontal flip.
    - ``transform_test``: no spatial augmentation.

    Both end with tensor conversion and normalization.

    Args:
        n_test_runs, n_ways, n_shots, n_queries, n_aug_support_samples, fg:
            episode configuration, forwarded unchanged to ``TieredImageNet_Meta``.
        batch_size: episodes per batch (default 1).
        num_workers: DataLoader worker processes.
        root: dataset directory passed as ``data_root``. The default is the
            previously hard-coded cluster path; override it on other machines.

    Returns:
        Non-shuffled DataLoader that keeps the last partial batch. The dataset
        is built with ``fix_seed=True``.
    """
    transform_meta_train = transforms.Compose([
        transforms.RandomCrop(84, padding=8),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4783, 0.4564, 0.4102], std=[0.2794756041544754, 0.27362885638931284, 0.28587171211901724])
    ])

    transform_meta_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.4783, 0.4564, 0.4102], std=[0.2794756041544754, 0.27362885638931284, 0.28587171211901724])
    ])

    tieredimagenet_test = TieredImageNet_Meta(data_root=root,
                                              n_test_runs=n_test_runs,
                                              n_ways=n_ways,
                                              n_shots=n_shots,
                                              n_queries=n_queries,
                                              n_aug_support_samples=n_aug_support_samples,
                                              partition='test',
                                              transform_train=transform_meta_train,
                                              transform_test=transform_meta_test,
                                              fg=fg,
                                              fix_seed=True)

    tieredimagenet_test_loader = DataLoader(tieredimagenet_test, shuffle=False, num_workers=num_workers, batch_size=batch_size, drop_last=False)

    return tieredimagenet_test_loader


def get_training_dataloader_cifar10(batch_size=16, num_workers=2, shuffle=True):
    """Return a DataLoader over the torchvision CIFAR-10 training split.

    Augmentation: random 32x32 crop with 4px padding and a random horizontal
    flip, then tensor conversion and normalization with fixed per-channel
    constants.

    NOTE(review): the std constants (0.2023, 0.1994, 0.2010) differ from the
    commonly quoted per-channel pixel std of CIFAR-10 (~0.247, 0.243, 0.261).
    Kept as-is because changing them alters trained models; confirm intent.

    Args:
        batch_size: DataLoader batch size.
        num_workers: number of DataLoader worker processes.
        shuffle: whether the DataLoader shuffles samples.

    Returns:
        torch.utils.data.DataLoader over ``torchvision.datasets.CIFAR10``
        (root='./data', train=True, download=True).
    """
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    cifar10_training = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    cifar10_training_loader = DataLoader(cifar10_training, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    return cifar10_training_loader

def get_test_dataloader_cifar10(batch_size=16, num_workers=2, shuffle=True):
    """Return a DataLoader over the torchvision CIFAR-10 test split (no augmentation).

    Images are converted to tensors and normalized with the same fixed
    constants as ``get_training_dataloader_cifar10``.

    Args:
        batch_size: DataLoader batch size.
        num_workers: number of DataLoader worker processes.
        shuffle: whether the DataLoader shuffles samples (defaults to True,
            even for evaluation).

    Returns:
        torch.utils.data.DataLoader over ``torchvision.datasets.CIFAR10``
        (root='./data', train=False, download=True).
    """
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
    ])
    cifar10_test = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    cifar10_test_loader = DataLoader(cifar10_test, shuffle=shuffle, num_workers=num_workers, batch_size=batch_size)

    return cifar10_test_loader


class TwoCropsTransform:
    """Wrap a transform so each call returns two independently transformed views.

    With a random ``base_transform`` the two views differ. They are returned
    as ``[q, k]`` (query first, key second), e.g. for contrastive training.
    """

    def __init__(self, base_transform):
        self.base_transform = base_transform

    def __call__(self, x):
        # Applied twice, in order: the first result is the query view,
        # the second the key view.
        return [self.base_transform(x) for _ in range(2)]


class GaussianBlur(object):
    """Gaussian blur augmentation in SimCLR https://arxiv.org/abs/2002.05709

    Each call blurs the image with a radius drawn uniformly from
    ``[sigma[0], sigma[1]]``.

    Args:
        sigma: ``(min, max)`` range for the blur radius; defaults to ``(0.1, 2.0)``.
    """

    def __init__(self, sigma=(.1, 2.)):
        # Immutable default: a list default would be a single object shared
        # (and mutable) across every instance created without ``sigma``.
        self.sigma = sigma

    def __call__(self, x):
        """Return ``x`` blurred with a freshly sampled radius.

        ``x`` must provide ``.filter()`` (e.g. a ``PIL.Image.Image``).
        """
        radius = random.uniform(self.sigma[0], self.sigma[1])
        return x.filter(ImageFilter.GaussianBlur(radius=radius))


def compute_mean_std(cifar100_dataset):
    """Compute the per-channel mean and std of an image dataset.

    Each sample is read exactly once. If the dataset applies random
    transforms, all three channels' statistics therefore come from the same
    image, and the dataset is traversed once instead of three times.

    Args:
        cifar100_dataset: indexable dataset supporting ``len()`` whose items
            hold, at index 1, an array with at least 3 channels on its last
            axis. NOTE(review): this assumes a ``(label, image)`` item layout;
            torchvision-style datasets return ``(image, label)``. Confirm for
            the dataset class in use.

    Returns:
        ``(mean, std)``, each a 3-tuple of per-channel values for channels 0, 1, 2.

    Raises:
        ValueError: if the dataset is empty.
    """
    images = [cifar100_dataset[i][1] for i in range(len(cifar100_dataset))]
    if not images:
        raise ValueError('compute_mean_std needs a non-empty dataset')

    # dstack of 2-D channel slices -> (H, W, N) per channel
    data_r = numpy.dstack([img[:, :, 0] for img in images])
    data_g = numpy.dstack([img[:, :, 1] for img in images])
    data_b = numpy.dstack([img[:, :, 2] for img in images])
    mean = numpy.mean(data_r), numpy.mean(data_g), numpy.mean(data_b)
    std = numpy.std(data_r), numpy.std(data_g), numpy.std(data_b)

    return mean, std

class WarmUpLR(_LRScheduler):
    """Linear learning-rate warm-up scheduler.

    After ``t`` scheduler steps (``self.last_epoch == t``) each parameter
    group's learning rate is ``base_lr * t / total_iters``. The rate rises
    linearly from 0 toward ``base_lr``. It is intended to be stepped once per
    batch during the warm-up phase.

    Args:
        optimizer: the wrapped optimizer (e.g. SGD).
        total_iters: number of steps in the warm-up phase.
        last_epoch: index of the last step; -1 starts fresh.
    """

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        # Set before the base-class constructor, which performs an initial
        # step that calls get_lr().
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one warm-up learning rate per parameter group."""
        denominator = self.total_iters + 1e-8  # epsilon avoids division by zero
        step = self.last_epoch
        return [initial * step / denominator for initial in self.base_lrs]


def most_recent_folder(net_weights, fmt):
    """Return the name of the newest non-empty subfolder of ``net_weights``.

    Subfolder names are parsed with ``datetime.datetime.strptime(name, fmt)``
    and the latest timestamp wins. Plain files and empty subfolders are
    ignored.

    Args:
        net_weights: directory containing timestamp-named run folders.
        fmt: ``strptime`` format that the folder names follow.

    Returns:
        The folder name (not a full path), or ``''`` if no non-empty
        subfolder exists.

    Raises:
        ValueError: if a non-empty subfolder's name does not match ``fmt``.
    """
    candidates = []
    for name in os.listdir(net_weights):
        path = os.path.join(net_weights, name)
        # Skip stray files: calling os.listdir() on a file raises NotADirectoryError.
        if os.path.isdir(path) and os.listdir(path):
            candidates.append(name)

    if not candidates:
        return ''

    # Stable sort by creation timestamp encoded in the name; last is newest.
    candidates = sorted(candidates, key=lambda f: datetime.datetime.strptime(f, fmt))
    return candidates[-1]

def most_recent_weights(weights_folder):
    """Return the weights file with the highest epoch number in ``weights_folder``.

    File names are searched for ``<net>-<epoch>-<regular|best>``. Files that
    do not contain that pattern are ignored.

    Args:
        weights_folder: directory holding checkpoint files.

    Returns:
        The file name (not a full path), or ``''`` when the folder holds no
        matching weights file. Among files with the same epoch, the one listed
        last by ``os.listdir`` wins.
    """
    pattern = re.compile(r'([A-Za-z0-9]+)-([0-9]+)-(regular|best)')

    # Check the *file list* for emptiness (not the folder path string), and
    # skip non-checkpoint files instead of crashing on a failed match.
    candidates = []
    for name in os.listdir(weights_folder):
        match = pattern.search(name)
        if match is not None:
            candidates.append((int(match.group(2)), name))

    if not candidates:
        return ''

    # Stable sort by epoch; the last entry is the most recent.
    candidates.sort(key=lambda pair: pair[0])
    return candidates[-1][1]

def last_epoch(weights_folder):
    """Return the epoch number of the most recent weights file in ``weights_folder``.

    Args:
        weights_folder: directory holding checkpoint files.

    Returns:
        The epoch encoded in the file name chosen by ``most_recent_weights``.

    Raises:
        Exception: if no recent weights file can be found or parsed.
    """
    weight_file = most_recent_weights(weights_folder)
    if not weight_file:
        raise Exception('no recent weights were found')

    # Parse with the same pattern most_recent_weights() sorts by, so the epoch
    # agrees even when the file name has extra '-'-separated parts
    # (e.g. 'my-net-10-regular.pth', where split('-')[1] would give 'net').
    match = re.search(r'([A-Za-z0-9]+)-([0-9]+)-(regular|best)', weight_file)
    if match is None:
        raise Exception('no recent weights were found')
    resume_epoch = int(match.group(2))

    return resume_epoch

def best_acc_weights(weights_folder):
    """Return the highest-epoch ``best`` checkpoint file in ``weights_folder``.

    File names are searched for ``<net>-<epoch>-<regular|best>``. Only files
    tagged ``best`` are considered, and files without the pattern are ignored.

    Args:
        weights_folder: directory holding checkpoint files.

    Returns:
        The file name (not a full path), or ``''`` if no ``best`` weights file
        is found.
    """
    pattern = re.compile(r'([A-Za-z0-9]+)-([0-9]+)-(regular|best)')

    best_files = []
    for name in os.listdir(weights_folder):
        match = pattern.search(name)
        # Skip unrelated files (logs, etc.) instead of crashing on a None match.
        if match is not None and match.group(3) == 'best':
            best_files.append((int(match.group(2)), name))

    if not best_files:
        return ''

    # Stable sort by epoch; the last entry has the highest epoch.
    best_files.sort(key=lambda pair: pair[0])
    return best_files[-1][1]


def adjust_learning_rate_cos(optimizer, lr, epoch, num_epochs, num_cycles):
    """Set every param group's lr from a cyclic cosine schedule.

    Training is split into ``num_cycles`` cycles of
    ``floor(num_epochs / num_cycles)`` epochs. Within each cycle the rate
    follows ``lr * 0.5 * (1 + cos(pi * t / epochs_per_cycle))``, where ``t`` is
    the epoch offset inside the cycle. It restarts at ``lr`` at every cycle
    boundary.

    Args:
        optimizer: object exposing ``param_groups`` (list of dicts with ``'lr'``).
        lr: base (maximum) learning rate.
        epoch: current epoch index.
        num_epochs: total number of training epochs.
        num_cycles: number of cosine cycles.

    Raises:
        ValueError: if ``num_cycles`` is not positive or exceeds ``num_epochs``
            (a cycle would be shorter than one epoch).
    """
    if num_cycles <= 0:
        raise ValueError('num_cycles must be positive, got %r' % (num_cycles,))
    epochs_per_cycle = math.floor(num_epochs / num_cycles)
    if epochs_per_cycle < 1:
        raise ValueError('num_cycles (%r) must not exceed num_epochs (%r)' % (num_cycles, num_epochs))

    lr *= 0.5 * (1. + math.cos(math.pi * (epoch % epochs_per_cycle) / epochs_per_cycle))

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr

    return None


def adjust_learning_rate_multistep(optimizer, lr, epoch, milestones, gamma):
    """Set every param group's lr using step decay.

    The base ``lr`` is multiplied by ``gamma`` once for each milestone that
    ``epoch`` has reached (``epoch >= milestone``).

    Args:
        optimizer: object exposing ``param_groups`` (list of dicts with ``'lr'``).
        lr: base learning rate.
        epoch: current epoch index.
        milestones: epochs at which a decay step kicks in.
        gamma: multiplicative decay factor per reached milestone.

    Returns:
        None; the optimizer is updated in place.
    """
    decayed = lr
    for milestone in milestones:
        factor = gamma if epoch >= milestone else 1.
        decayed *= factor

    for group in optimizer.param_groups:
        group['lr'] = decayed

    return None


def warmup_learning_rate(optimizer, base_lr, last_iter, total_iters):
    """Set every param group's lr to a linear warm-up value.

    The value is ``base_lr * last_iter / total_iters``. A small epsilon in the
    denominator avoids division by zero when ``total_iters`` is 0.

    Returns:
        None; the optimizer is updated in place.
    """
    denominator = total_iters + 1e-8
    new_lr = base_lr * last_iter / denominator
    for group in optimizer.param_groups:
        group.update(lr=new_lr)
    return None


def write_values(vals, destpath):
    """Write each value on its own line, formatted with six decimal places.

    Args:
        vals: iterable of numbers (any iterable, not only sequences).
        destpath: output file path; it is overwritten and UTF-8 encoded.

    Returns:
        None.
    """
    # Context manager closes the file even if formatting/writing raises.
    with open(destpath, 'w', encoding='utf-8') as out:
        out.writelines('%.6f\n' % value for value in vals)

    return None
