import math
import os
import random
import sys
import time
import warnings

import numpy
import numpy as np
import torch
import shutil
import logging

import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn.functional as F
import torchvision.datasets as dset
import torch.nn as nn

# Escalate every warning to an exception for the whole process. This runs as a
# side effect of importing this module and also applies to warnings emitted by
# third-party libraries (torch, torchvision, numpy, ...).
# NOTE(review): confirm this is intentional — a deprecation/future warning from
# a library call made by the helpers below will then raise instead of printing.
warnings.filterwarnings("error")


def timeSince(since=None, s=None):
    """Format a duration as ``'<h>h <m>m <s>s'``.

    :param since: start timestamp (as returned by ``time.time()``); used only
        when *s* is None, in which case the duration is the whole number of
        seconds elapsed since then.
    :param s: duration in seconds; takes precedence over *since*.
    :return: formatted string, e.g. ``'1h 2m 5s'`` for ``s=3725``.
    """
    total = int(time.time() - since) if s is None else s
    minutes = math.floor(total / 60)
    # Hours come from the un-reduced minute count; minutes/seconds are then
    # reduced modulo 60 for display.
    hours = math.floor(minutes / 60)
    return '%dh %dm %ds' % (hours, minutes % 60, total % 60)


class AvgrageMeter(object):
    """Track the latest value and a count-weighted running mean.

    Attributes (all reset to 0 by :meth:`reset`):
      val   -- the value passed to the most recent :meth:`update`
      sum   -- running total of ``val * n`` over all updates
      count -- running total of ``n`` over all updates
      avg   -- ``sum / count``
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero every statistic."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record *val* as observed *n* times and refresh :attr:`avg`."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def get_correct_num(y, target):
    """Count samples whose argmax over dim 1 of *y* equals *target*.

    :param y: score tensor; the predicted label is ``argmax`` along dim 1.
    :param target: tensor of true labels, comparable element-wise with the
        predicted labels.
    :return: number of correct predictions as a Python number.
    """
    predicted = y.argmax(dim=1)
    hits = predicted.eq(target)
    return hits.sum().item()


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy, in percent, for each k in *topk*.

    :param output: score tensor of shape ``(batch, num_classes)``; the top-k
        classes are taken along dim 1 (largest scores first).
    :param target: tensor of ``batch`` true class indices.
    :param topk: iterable of k values to evaluate.
    :return: list of 0-d float tensors, one per k in *topk* (same order),
        each equal to ``100 * (#samples whose label is in their top-k) / batch``.
    """
    maxk = max(topk)
    batch_size = target.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()  # (maxk, batch): row j holds each sample's j-th best class
    correct = pred.eq(target.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        # reshape, not view: `correct` can inherit pred.t()'s non-contiguous
        # layout, and view(-1) raises on non-contiguous tensors.
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res


class Cutout(object):
    """Zero out one square patch of a ``(C, H, W)`` image tensor, in place.

    The patch centre is drawn uniformly with ``np.random.randint`` (row, then
    column); its half-width is ``length // 2`` and it is clipped to the image
    borders, so patches near an edge are smaller.
    """

    def __init__(self, length):
        self.length = length  # nominal side length of the erased square

    def __call__(self, img):
        """Apply the cutout to *img* (modified in place) and return it."""
        height = img.size(1)
        width = img.size(2)
        half = self.length // 2

        center_y = np.random.randint(height)
        center_x = np.random.randint(width)
        top, bottom = np.clip([center_y - half, center_y + half], 0, height)
        left, right = np.clip([center_x - half, center_x + half], 0, width)

        keep = np.ones((height, width), np.float32)
        keep[top:bottom, left:right] = 0.
        # Broadcast the 2-D mask over every channel.
        img *= torch.from_numpy(keep).expand_as(img)
        return img


def _data_transforms_cifar10(args):
    """Build the (train, valid) transform pipelines for CIFAR-10.

    Training: random 32x32 crop with 4-pixel padding, random horizontal flip,
    tensor conversion and per-channel normalization, followed by
    ``Cutout(args.cutout_length)`` when ``args.cutout`` is truthy.
    Validation: tensor conversion and the same normalization only.

    :param args: namespace providing ``cutout`` and ``cutout_length``.
    :return: tuple ``(train_transform, valid_transform)`` of ``transforms.Compose``.
    """
    # Per-channel mean/std used for normalization (CIFAR-10 statistics).
    mean = [0.49139968, 0.48215827, 0.44653124]
    std = [0.24703233, 0.24348505, 0.26158768]

    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    if args.cutout:
        # Cutout runs last, on the already-normalized tensor.
        train_steps.append(Cutout(args.cutout_length))

    valid_steps = [transforms.ToTensor(), transforms.Normalize(mean, std)]
    return transforms.Compose(train_steps), transforms.Compose(valid_steps)


def _get_cifar10(args):
    """Create ``(train_queue, valid_queue)`` DataLoaders for CIFAR-10.

    The dataset is downloaded to ``args.data`` if missing. Both loaders use
    ``args.batch_size``, 4 workers and pinned memory; only the training loader
    shuffles.
    """
    train_transform, valid_transform = _data_transforms_cifar10(args)

    def _split(is_train, transform):
        # One CIFAR-10 split rooted at args.data.
        return dset.CIFAR10(root=args.data, train=is_train, download=True, transform=transform)

    train_data = _split(True, train_transform)
    valid_data = _split(False, valid_transform)

    loader_kwargs = dict(batch_size=args.batch_size, pin_memory=True, num_workers=4)
    train_queue = torch.utils.data.DataLoader(train_data, shuffle=True, **loader_kwargs)
    valid_queue = torch.utils.data.DataLoader(valid_data, shuffle=False, **loader_kwargs)
    return train_queue, valid_queue


def _get_dist_cifar10(args):
    """Create CIFAR-10 loaders for distributed training.

    The training set is split across ``args.gpu_num`` replicas by a
    ``DistributedSampler`` with ``rank=args.local_rank`` (NOTE(review): this
    presumably assumes single-node training, where local rank == global rank).
    Each replica's training batch is ``args.batch_size // args.gpu_num`` and the
    last incomplete batch is dropped; the validation loader is not sharded.

    :return: ``(train_queue, valid_queue, sampler)``; the sampler is returned
        so the caller can drive it (e.g. per-epoch reshuffling).
    """
    train_transform, valid_transform = _data_transforms_cifar10(args)
    train_data = dset.CIFAR10(root=args.data, train=True, download=True, transform=train_transform)
    valid_data = dset.CIFAR10(root=args.data, train=False, download=True, transform=valid_transform)

    sampler = torch.utils.data.distributed.DistributedSampler(
        train_data, num_replicas=args.gpu_num, rank=args.local_rank)

    per_replica_batch = args.batch_size // args.gpu_num
    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=per_replica_batch,
        sampler=sampler,
        drop_last=True,
        pin_memory=True,
        num_workers=4,
    )
    valid_queue = torch.utils.data.DataLoader(
        valid_data,
        batch_size=args.batch_size,
        shuffle=False,
        pin_memory=True,
        num_workers=4,
    )
    return train_queue, valid_queue, sampler


def _get_dist_imagenet(args):
    """Create ImageNet loaders for distributed training.

    Reads ``args.data_dir/train`` and ``args.data_dir/val`` as ``ImageFolder``
    datasets. Training uses random-resized 224 crops, horizontal flips and color
    jitter; validation resizes to 256 and center-crops 224. The training set is
    sharded by a ``DistributedSampler`` over ``args.gpu_num`` replicas with
    ``rank=args.local_rank`` and per-replica batch ``args.batch_size // args.gpu_num``.

    :return: ``(train_loader, val_loader, sampler)``.
    """
    train_dir = os.path.join(args.data_dir, 'train')
    val_dir = os.path.join(args.data_dir, 'val')
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

    augment = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.2),
        transforms.ToTensor(),
        normalize,
    ])
    train_dataset = dset.ImageFolder(train_dir, augment)

    sampler = torch.utils.data.distributed.DistributedSampler(
        train_dataset, num_replicas=args.gpu_num, rank=args.local_rank)

    # At least 4 loader workers, more when there are many replicas.
    train_workers = max(args.gpu_num * 2, 4)
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        batch_size=args.batch_size // args.gpu_num,
        num_workers=train_workers,
        pin_memory=True,
        drop_last=True,
        sampler=sampler,
    )

    evaluation = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])
    val_dataset = dset.ImageFolder(val_dir, evaluation)
    val_loader = torch.utils.data.DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        pin_memory=True,
    )
    return train_loader, val_loader, sampler


def _data_transforms_cifar100(args):
    """Build the (train, valid) transform pipelines for CIFAR-100.

    Same recipe as the CIFAR-10 pipelines (pad-4 random 32x32 crop, horizontal
    flip, tensor conversion, normalization, optional trailing Cutout) but with
    CIFAR-100 normalization statistics.

    :param args: namespace providing ``cutout`` and ``cutout_length``.
    :return: tuple ``(train_transform, valid_transform)`` of ``transforms.Compose``.
    """
    # Per-channel mean/std used for normalization (CIFAR-100 statistics).
    mean = [0.5070751592371323, 0.48654887331495095, 0.4409178433670343]
    std = [0.2673342858792401, 0.2564384629170883, 0.27615047132568404]

    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]
    if args.cutout:
        # Cutout runs last, on the already-normalized tensor.
        train_steps.append(Cutout(args.cutout_length))

    valid_steps = [transforms.ToTensor(), transforms.Normalize(mean, std)]
    return transforms.Compose(train_steps), transforms.Compose(valid_steps)


def _get_cifar100(args):
    """Create ``(train_queue, valid_queue)`` DataLoaders for CIFAR-100.

    The dataset is downloaded to ``args.data`` if missing. Both loaders use
    ``args.batch_size``, 4 workers and pinned memory; only the training loader
    shuffles.
    """
    train_transform, valid_transform = _data_transforms_cifar100(args)

    def _split(is_train, transform):
        # One CIFAR-100 split rooted at args.data.
        return dset.CIFAR100(root=args.data, train=is_train, download=True, transform=transform)

    train_data = _split(True, train_transform)
    valid_data = _split(False, valid_transform)

    loader_kwargs = dict(batch_size=args.batch_size, pin_memory=True, num_workers=4)
    train_queue = torch.utils.data.DataLoader(train_data, shuffle=True, **loader_kwargs)
    valid_queue = torch.utils.data.DataLoader(valid_data, shuffle=False, **loader_kwargs)
    return train_queue, valid_queue


def _get_imagenet_tiny(args):
    """Create ``(train_queue, valid_queue)`` DataLoaders for Tiny-ImageNet.

    Expects ``args.data`` to contain ``train`` and ``val`` subdirectories in
    ``ImageFolder`` layout. Training uses pad-4 random 64x64 crops, horizontal
    flips, normalization and an optional trailing Cutout; validation only
    converts and normalizes. The validation batch is half the training batch.
    """
    data_root = args.data
    normalize = transforms.Normalize(
        mean=[0.4802, 0.4481, 0.3975],
        std=[0.2302, 0.2265, 0.2262],
    )

    train_steps = [
        transforms.RandomCrop(64, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ]
    if args.cutout:
        train_steps.append(Cutout(args.cutout_length))

    train_data = dset.ImageFolder(os.path.join(data_root, 'train'), transforms.Compose(train_steps))
    valid_data = dset.ImageFolder(
        os.path.join(data_root, 'val'),
        transforms.Compose([transforms.ToTensor(), normalize]),
    )

    train_queue = torch.utils.data.DataLoader(
        train_data,
        batch_size=args.batch_size,
        shuffle=True,
        pin_memory=True,
        num_workers=4,
    )
    valid_queue = torch.utils.data.DataLoader(
        valid_data,
        batch_size=args.batch_size // 2,
        shuffle=False,
        pin_memory=True,
        num_workers=4,
    )
    return train_queue, valid_queue


def count_parameters_in_MB(model):
    """Return the number of parameter elements of *model*, in millions.

    Despite the name, this counts elements (divided by 1e6), not bytes.
    """
    element_counts = [param.numel() for param in model.parameters()]
    return np.sum(element_counts) / 1e6


def count_parameters(model):
    """
    Count the total number of elements across all parameters of a model.

    :param model: module whose ``parameters()`` are counted.
    :return: total element count (0 for a module without parameters).
    """
    total = 0
    for param in model.parameters():
        total += param.numel()
    return total


def save(model, model_path):
    """Write ``model.state_dict()`` to *model_path* with ``torch.save``."""
    state = model.state_dict()
    torch.save(state, model_path)


def load(model, model_path):
    """Load the state dict stored at *model_path* into *model* in place."""
    state = torch.load(model_path)
    model.load_state_dict(state)


def load_ckpt(ckpt_path):
    """Load a checkpoint with ``torch.load``, terminating the process on failure.

    :param ckpt_path: path to the checkpoint file.
    :return: the object ``torch.load`` deserializes from *ckpt_path*.
    :raises SystemExit: with exit status 1 if loading raises any ``Exception``.
    """
    print(f'=> loading checkpoint {ckpt_path}...')
    try:
        checkpoint = torch.load(ckpt_path)
    except Exception as err:
        # Catch Exception (not a bare except) so Ctrl-C is not reported as a
        # load failure, and exit non-zero: the previous bare `exit()` reported
        # success (status 0) to the shell/scheduler.
        # NOTE(review): this module turns warnings into errors at import time,
        # so a warning emitted by torch.load also ends up here.
        print(f"=> fail loading {ckpt_path}... ({err})")
        sys.exit(1)
    return checkpoint


def save_ckpt(ckpt, file_dir, file_name='model.ckpt', is_best=False):
    """Serialize *ckpt* to ``os.path.join(file_dir, file_name)`` with ``torch.save``.

    :param ckpt: any object ``torch.save`` can serialize.
    :param file_dir: destination directory, created (with parents) if missing;
        ``''`` means the current working directory.
    :param file_name: checkpoint file name inside *file_dir*.
    :param is_best: if true, also copy the file to ``best_<file_name>`` in *file_dir*.
    """
    # os.makedirs('') raises, so only create a directory when one was given;
    # exist_ok avoids the race between a separate exists() check and makedirs().
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)
    ckpt_path = os.path.join(file_dir, file_name)
    torch.save(ckpt, ckpt_path)
    if is_best:
        shutil.copyfile(ckpt_path, os.path.join(file_dir, f'best_{file_name}'))


def drop_path(x, drop_prob, dims=(0,)):
    """Randomly zero slices of *x* in place and rescale the survivors.

    A Bernoulli(keep_prob) mask is drawn with the size of *x* along each axis
    in *dims* and size 1 elsewhere, so one draw covers each whole slice (with
    the default ``dims=(0,)``: one draw per index of dim 0). Surviving entries
    are divided by ``keep_prob = 1 - drop_prob``. Nothing happens when
    ``drop_prob <= 0``.

    :param x: tensor to modify in place.
    :param drop_prob: probability of zeroing each slice. NOTE(review):
        ``drop_prob >= 1`` divides by zero and yields NaN/inf.
    :param dims: axes along which independent draws are made.
    :return: *x* itself (modified in place).
    """
    mask_shape = [1] * x.dim()
    for axis in dims:
        mask_shape[axis] = x.size(axis)
    if drop_prob > 0.:
        keep_prob = 1. - drop_prob
        # new_empty allocates the mask on x's device with x's dtype (the old
        # torch.cuda.FloatTensor only worked for float32 CUDA inputs).
        mask = x.new_empty(mask_shape).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x


def create_exp_dir(path, scripts_to_save=None):
    """Create an experiment directory and optionally snapshot source files into it.

    :param path: experiment directory; created with parents if missing. Re-using
        an existing directory is allowed.
    :param scripts_to_save: optional iterable of file paths copied (by basename)
        into ``<path>/tools``, overwriting earlier copies.
    """
    # exist_ok lets a run be restarted into the same experiment directory; the
    # old unconditional makedirs(path/tools) raised FileExistsError on re-run.
    if path:
        os.makedirs(path, exist_ok=True)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        tools_dir = os.path.join(path, 'tools')
        os.makedirs(tools_dir, exist_ok=True)
        for script in scripts_to_save:
            shutil.copyfile(script, os.path.join(tools_dir, os.path.basename(script)))


class Performance(object):
    """Record softmax-normalized architecture weights and a loss, one row per update.

    Each :meth:`update` appends a row to :attr:`data` holding the flattened
    ``softmax(alphas_normal, dim=-1)``, the flattened
    ``softmax(alphas_reduce, dim=-1)`` and the flattened *val_loss*;
    :meth:`save` writes the stacked rows to :attr:`path` with ``np.save``.
    """

    def __init__(self, path):
        self.path = path  # destination passed to np.save
        self.data = None  # 2-D ndarray of recorded rows; None before the first update

    @staticmethod
    def _flat_numpy(value):
        """Return *value* as a detached, host-memory, 1-D NumPy array."""
        # np.concatenate cannot convert CUDA tensors (or tensors that require
        # grad) on its own, so detach and copy to CPU explicitly.
        return torch.as_tensor(value).detach().cpu().numpy().reshape(-1)

    def update(self, alphas_normal, alphas_reduce, val_loss):
        """Append one row built from the current alphas and *val_loss*."""
        a_normal = F.softmax(alphas_normal, dim=-1)
        a_reduce = F.softmax(alphas_reduce, dim=-1)
        row = np.concatenate([
            self._flat_numpy(a_normal),
            self._flat_numpy(a_reduce),
            self._flat_numpy(val_loss),
        ]).reshape(1, -1)
        if self.data is None:
            self.data = row
        else:
            # NOTE(review): re-concatenating copies all previous rows each call.
            self.data = np.concatenate([self.data, row], axis=0)

    def save(self):
        """Write the recorded rows to :attr:`path` with ``np.save``."""
        np.save(self.path, self.data)


def logger(log_dir, need_time=True, need_stdout=False):
    """Attach DEBUG-level handlers to this module's logger and return it.

    :param log_dir: path of the log *file* opened by a ``logging.FileHandler``.
    :param need_time: if true, prefix records with a ``'%m/%d/%Y-%I:%M:%S'``
        timestamp on every attached handler.
    :param need_stdout: if true, also echo records to ``sys.stdout``
        (that handler is attached before the file handler).
    :return: the ``logging.getLogger(__name__)`` logger. Every call adds new
        handlers to that same logger, so repeated calls duplicate output.
    """
    log = logging.getLogger(__name__)
    log.setLevel(logging.DEBUG)

    file_handler = logging.FileHandler(log_dir)
    handlers = [logging.StreamHandler(sys.stdout)] if need_stdout else []
    handlers.append(file_handler)

    stamp = None
    if need_time:
        stamp = logging.Formatter(fmt='%(asctime)s %(message)s', datefmt='%m/%d/%Y-%I:%M:%S')

    for handler in handlers:
        handler.setLevel(logging.DEBUG)
        if stamp is not None:
            handler.setFormatter(stamp)
        log.addHandler(handler)
    return log


class CrossEntropyLabelSmooth(nn.Module):
    """Cross-entropy against label-smoothed one-hot targets.

    For a sample with label ``y`` the target distribution is
    ``(1 - epsilon) * one_hot(y) + epsilon / num_classes``; the loss is the
    batch mean of ``-sum_c target_c * log_softmax(inputs, dim=1)_c``.
    """

    def __init__(self, num_classes, epsilon):
        super(CrossEntropyLabelSmooth, self).__init__()
        self.num_classes = num_classes
        self.epsilon = epsilon
        self.logsoftmax = nn.LogSoftmax(dim=1)

    def forward(self, inputs, targets):
        """Return the scalar smoothed loss for logits *inputs* and label indices *targets*."""
        log_probs = self.logsoftmax(inputs)
        one_hot = torch.zeros_like(log_probs)
        one_hot.scatter_(1, targets.unsqueeze(1), 1)
        smoothed = one_hot * (1 - self.epsilon) + self.epsilon / self.num_classes
        return (-(smoothed * log_probs)).mean(0).sum()


def roc_auc_compute_fn(y_pred, y_target):
    """Compute ROC AUC of score tensor *y_pred* against label tensor *y_target*.

    Adapted from ``ignite.contrib.metrics.ROC_AUC``. Both tensors are moved to
    host memory (scores are detached first if they require grad) and passed to
    ``sklearn.metrics.roc_auc_score``.

    :return: the AUC, or ``0.`` when ``roc_auc_score`` raises ``ValueError``
        (e.g. only one class present in *y_target*).
    :raises RuntimeError: if scikit-learn is not installed.
    """
    try:
        from sklearn.metrics import roc_auc_score
    except ImportError:
        raise RuntimeError("This contrib module requires sklearn to be installed.")

    scores = y_pred.detach() if y_pred.requires_grad else y_pred
    labels = y_target.cpu() if y_target.is_cuda else y_target
    if scores.is_cuda:
        scores = scores.cpu()

    labels_np = labels.numpy()
    scores_np = scores.numpy()
    try:
        return roc_auc_score(labels_np, scores_np)
    except ValueError:
        # NOTE(review): 0. is indistinguishable from a genuinely bad classifier;
        # AUC is actually undefined here.
        return 0.


def load_checkpoint(args):
    """Load the checkpoint at ``args.resume`` with ``torch.load``.

    :param args: namespace providing ``resume`` (checkpoint path).
    :return: the deserialized checkpoint object.
    :raises RuntimeError: if ``torch.load`` raises ``RuntimeError`` (e.g. a
        truncated archive); the original error is chained as ``__cause__``.
        Other errors (such as a missing file) propagate unchanged.
    """
    try:
        return torch.load(args.resume)
    except RuntimeError as err:
        raise RuntimeError(f"Fail to load checkpoint at {args.resume}") from err


def save_checkpoint(ckpt, is_best, file_dir, file_name='model.ckpt'):
    """Serialize *ckpt* into *file_dir* and optionally keep a "best" copy.

    :param ckpt: any object ``torch.save`` can serialize.
    :param is_best: if true, also copy the file to ``best_<file_name>`` in *file_dir*.
    :param file_dir: destination directory, created (with parents) if missing;
        works with or without a trailing separator; ``''`` means the current directory.
    :param file_name: checkpoint file name inside *file_dir*.
    """
    # os.makedirs('') raises, so only create a directory when one was given.
    if file_dir:
        os.makedirs(file_dir, exist_ok=True)
    # os.path.join instead of plain concatenation: "ckpts" + "model.ckpt" used
    # to produce "ckptsmodel.ckpt" outside the directory created above.
    ckpt_name = os.path.join(file_dir, file_name)
    torch.save(ckpt, ckpt_name)
    if is_best:
        shutil.copyfile(ckpt_name, os.path.join(file_dir, 'best_' + file_name))


def seed_everything(seed=2022):
    """Seed Python, NumPy and PyTorch (CPU and current CUDA device) RNGs with *seed*.

    Also exports ``PYTHONHASHSEED`` (this affects only processes started
    afterwards, not the current interpreter's hashing) and sets
    ``torch.backends.cudnn.deterministic = True``.

    [reference] https://gist.github.com/KirillVladimirov/005ec7f762293d2321385580d3dbe335
    """
    os.environ['PYTHONHASHSEED'] = str(seed)
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
        seeder(seed)
    torch.backends.cudnn.deterministic = True
