import os
import random
import shutil

import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn.functional as F
from torch.optim.lr_scheduler import _LRScheduler
from pytorch_cam.utils.image import show_cam_on_image
import torchvision.transforms.functional as TF
from torchvision.utils import make_grid
import torch.nn as nn


def show_one_row(axs, imgs, title):
    """Draw one image tensor on a single axes, labelled on the y-axis.

    Args:
        axs: Axes-like object exposing ``imshow`` and ``set``.
        imgs: Tensor accepted by ``TF.to_pil_image`` (callers in this file
            pass the output of ``make_grid``).
        title: Text used as the y-axis label of the row.
    """
    pil_image = TF.to_pil_image(imgs.detach())
    pixels = np.asarray(pil_image)
    axs.imshow(pixels)
    # Hide ticks and tick labels; let the image stretch to fill the axes.
    axs.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], ylabel=title, aspect="auto")


def show_one_row_cam(axs, imgs, masks, title):
    """Draw an image with a CAM-style overlay on a single axes.

    Args:
        axs: Axes-like object exposing ``imshow`` and ``set``.
        imgs: 3-D image tensor; its first axis is moved last before overlaying.
        masks: 3-D mask tensor, handled the same way as ``imgs``.
        title: Text used as the y-axis label of the row.
    """
    def _first_axis_last(tensor):
        # Detach, copy to host memory and reorder axes (0, 1, 2) -> (1, 2, 0).
        return np.transpose(tensor.detach().cpu().numpy(), (1, 2, 0))

    overlay = show_cam_on_image(_first_axis_last(imgs), _first_axis_last(masks), True)
    axs.imshow(overlay)
    # Hide ticks and tick labels; let the image stretch to fill the axes.
    axs.set(xticklabels=[], yticklabels=[], xticks=[], yticks=[], ylabel=title, aspect="auto")


def show_summary(x, mixed_x, index_list, tx_list, s_list, k_masks, lam_list, bs, save_path=None, prefix=''):
    """Plot a stacked overview of a k-way mix and optionally save it as a PNG.

    With ``k = len(index_list)`` the figure has ``4 * k + 1`` rows:
    rows ``[0, k)`` show ``x[index]`` for each index, rows ``[k, 2k)`` show the
    same grids overlaid with ``s_list``, rows ``[2k, 3k)`` show ``tx_list``,
    rows ``[3k, 4k)`` show ``tx_list`` overlaid with ``k_masks`` and the last
    row shows ``mixed_x``.

    Args:
        x: Batch tensor indexed by each entry of ``index_list``.
        mixed_x: Batch shown in the final row.
        index_list: Sequence of index tensors, one per mixed source.
        tx_list, s_list, k_masks: Sequences with at least ``k`` batch tensors.
        lam_list: Value embedded in the label of the final row.
        bs: Number of images per grid row (``nrow`` of ``make_grid``).
        save_path: Directory for ``{prefix}examples.png``; nothing is saved
            when ``None``.
        prefix: File-name prefix of the saved figure.
    """
    k = len(index_list)
    n_rows = 4 * k + 1

    def as_grid(batch):
        # Every row uses the same grid layout and per-grid normalisation.
        return make_grid(batch, nrow=bs, normalize=True)

    x_grids = [as_grid(x[index]) for index in index_list]
    tx_grids = [as_grid(tx_list[i]) for i in range(k)]
    s_grids = [as_grid(s_list[i]) for i in range(k)]
    m_grids = [as_grid(k_masks[i]) for i in range(k)]
    mixed_grid = as_grid(mixed_x)

    _, axs = plt.subplots(nrows=n_rows, ncols=1, squeeze=True, figsize=(16, 2*n_rows))
    plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)

    for i in range(k):
        show_one_row(axs[i], x_grids[i], f'x{i}')
        show_one_row_cam(axs[k+i], x_grids[i], s_grids[i], 's')
        show_one_row(axs[2*k+i], tx_grids[i], f'tx{i}')
        show_one_row_cam(axs[3*k+i], tx_grids[i], m_grids[i], 'm')

    show_one_row(axs[-1], mixed_grid, f'KD Mix ({lam_list})')

    if save_path is not None:
        plt.savefig(os.path.join(save_path, f'{prefix}examples.png'))


def cutmix(x, y, alpha=0.5, use_cuda=True):
    """Apply CutMix with one rectangle shared by the whole batch.

    A rectangle whose sides are ``sqrt(1 - alpha)`` times the height and width
    is centred at a uniformly random position, clipped to the image, and
    filled with the same region of a randomly permuted copy of the batch.

    NOTE: ``x`` is modified in place and the same tensor is returned.

    Args:
        x: Batch tensor of shape ``(batch, channels, height, width)``.
        y: Targets indexable by a permutation tensor.
        alpha: Controls the patch size; ``alpha=1`` yields an empty patch.
        use_cuda: When true, the permutation and the returned ratio are
            moved to the current CUDA device.

    Returns:
        Tuple ``(x, y_a, y_b, ratio)`` where ``y_a`` is ``y``, ``y_b`` is the
        permuted ``y`` and ``ratio`` is a float32 scalar tensor holding the
        fraction of each image that was NOT replaced.
    """
    batch_size, _, height, width = x.shape

    index = torch.randperm(batch_size)
    if use_cuda:
        index = index.cuda()

    # rx/ry are the patch centre along the height/width axes respectively.
    rx = np.random.uniform(0, height)
    ry = np.random.uniform(0, width)
    rh = np.sqrt(1 - alpha) * height
    rw = np.sqrt(1 - alpha) * width
    x1 = int(np.clip(rx - rh / 2, a_min=0., a_max=height))
    x2 = int(np.clip(rx + rh / 2, a_min=0., a_max=height))
    y1 = int(np.clip(ry - rw / 2, a_min=0., a_max=width))
    y2 = int(np.clip(ry + rw / 2, a_min=0., a_max=width))

    # Advanced indexing on the right-hand side copies the data first, so the
    # in-place assignment cannot read values it has already overwritten.
    x[:, :, x1:x2, y1:y2] = x[index, :, x1:x2, y1:y2]
    y_a, y_b = y, y[index]

    # One scalar for the whole batch; no per-sample array is needed (the old
    # ``ratio[0]`` lookup raised IndexError for an empty batch).
    kept_fraction = 1 - (x2 - x1) * (y2 - y1) / (width * height)
    ratio = torch.tensor(kept_fraction, dtype=torch.float32)
    if use_cuda:
        ratio = ratio.cuda()

    return x, y_a, y_b, ratio


def adjust_learning_rate(dataset, optimizer, epoch, gammas, schedule, learning_rate):
    """Compute the learning rate for ``epoch`` and store it in every param group.

    For ``dataset == 'imagenet'`` the rate is ``learning_rate`` multiplied by
    ``0.1`` once per 75 completed epochs. For any other dataset the rate is
    multiplied by each ``gamma`` whose matching ``schedule`` step has been
    reached; scanning stops at the first step that lies beyond ``epoch``.

    Args:
        dataset: Dataset name; only ``'imagenet'`` is treated specially.
        optimizer: Object exposing ``param_groups`` (list of dicts).
        epoch: Current epoch index.
        gammas: Multiplicative factors, one per entry of ``schedule``.
        schedule: Epoch thresholds, expected in increasing order.
        learning_rate: Initial learning rate.

    Returns:
        The learning rate that was written into the optimizer.
    """
    if dataset == 'imagenet':
        lr = learning_rate * (0.1**(epoch // 75))
    else:
        assert len(gammas) == len(schedule), "length of gammas and schedule should be equal"
        lr = learning_rate
        for gamma, step in zip(gammas, schedule):
            if epoch < step:
                break
            lr = lr * gamma

    for group in optimizer.param_groups:
        group['lr'] = lr
    return lr


class DotDict(dict):
    """Dictionary whose items can also be read, set and deleted as attributes.

    Items assigned through ``d[key] = value`` or ``d.key = value`` are mirrored
    into the instance ``__dict__``. Reading a missing attribute returns
    ``None`` instead of raising ``AttributeError``.

    Example:
        m = DotDict({'first_name': 'Eduardo'}, last_name='Pool', age=24, sports=['Soccer'])
        m.first_name   # 'Eduardo'
        m.unknown      # None
    """
    def __init__(self, *args, **kwargs):
        super(DotDict, self).__init__(*args, **kwargs)
        # Re-assign through __setitem__ so the attribute mirror is populated.
        for arg in args:
            if isinstance(arg, dict):
                for k, v in arg.items():
                    self[k] = v

        for k, v in kwargs.items():
            self[k] = v

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails; missing keys give None.
        return self.get(attr)

    def __setattr__(self, key, value):
        self.__setitem__(key, value)

    def __setitem__(self, key, value):
        super(DotDict, self).__setitem__(key, value)
        self.__dict__[key] = value

    def __delattr__(self, item):
        self.__delitem__(item)

    def __delitem__(self, key):
        super(DotDict, self).__delitem__(key)
        # The key may never have been mirrored (e.g. it was added through
        # ``update()`` or a list-of-pairs constructor argument), so drop it
        # tolerantly instead of raising KeyError after the item is already gone.
        self.__dict__.pop(key, None)


def save_checkpoint(state, is_best, save):
    """Serialize ``state`` to ``<save>/checkpoint.pth.tar``.

    When ``is_best`` is truthy the written file is additionally copied to
    ``<save>/model_best.pth.tar``.

    Args:
        state: Object handed to ``torch.save``.
        is_best: Whether this checkpoint should also become the best model.
        save: Existing directory that receives the file(s).
    """
    checkpoint_path = os.path.join(save, 'checkpoint.pth.tar')
    torch.save(state, checkpoint_path)
    if not is_best:
        return
    shutil.copyfile(checkpoint_path, os.path.join(save, 'model_best.pth.tar'))


def save_ckpt(model, optimizer, scheduler, epoch, model_path):
    """Write model, optimizer and (optional) scheduler state to ``model_path``.

    The saved dictionary has the keys ``'model'``, ``'epoch'``,
    ``'optimizer'`` and ``'scheduler'``; the last one is ``None`` when no
    scheduler is given.
    """
    checkpoint = {
        'model': model.state_dict(),
        'epoch': epoch,
        'optimizer': optimizer.state_dict(),
        'scheduler': None if scheduler is None else scheduler.state_dict(),
    }
    torch.save(checkpoint, model_path)


def save_cos_ckpt(model, optimizer, scheduler, epoch, model_path):
    """Write a checkpoint that also stores the wrapped scheduler's state.

    In addition to the keys written by a plain checkpoint (``'model'``,
    ``'epoch'``, ``'optimizer'``, ``'scheduler'``) the dictionary contains
    ``'cos_scheduler'``: the state of ``scheduler.after_scheduler``.

    Args:
        model: Module providing ``state_dict()``.
        optimizer: Optimizer providing ``state_dict()``.
        scheduler: Scheduler exposing ``state_dict()`` and an
            ``after_scheduler`` attribute, or ``None``.
        epoch: Epoch number stored alongside the states.
        model_path: Destination file.
    """
    if scheduler is None:
        # Previously only 'scheduler' was guarded and the unconditional
        # ``scheduler.after_scheduler`` access crashed with AttributeError.
        scheduler_state = None
        cos_scheduler_state = None
    else:
        scheduler_state = scheduler.state_dict()
        cos_scheduler_state = scheduler.after_scheduler.state_dict()

    torch.save({'model': model.state_dict(),
                'epoch': epoch,
                'optimizer': optimizer.state_dict(),
                'scheduler': scheduler_state,
                'cos_scheduler': cos_scheduler_state}, model_path)


def restore_ckpt(model, optimizer, scheduler, model_path, location):
    """Restore model, optimizer and scheduler state and return the saved epoch.

    The model weights are loaded with three attempts, in order:

    1. directly into ``model``;
    2. into ``model.module`` (``model`` is wrapped, the checkpoint is not);
    3. through a temporary ``WrappedModel`` around ``model`` (the checkpoint
       keys carry a ``module.`` prefix, ``model`` is not wrapped).

    Unlike the previous implementation, the fallback paths no longer return
    early with the model: optimizer and scheduler are always restored and the
    epoch is always returned.

    Args:
        model: Module to load weights into (modified in place).
        optimizer: Optimizer whose state is restored.
        scheduler: Scheduler whose state is restored, or ``None``.
        model_path: Checkpoint file with the keys ``'model'``, ``'optimizer'``,
            ``'scheduler'`` and ``'epoch'``.
        location: GPU index (int or digit string, mapped to ``cuda:<index>``),
            or a ``torch.device`` / device string such as ``'cpu'`` that is
            used as-is.

    Returns:
        The epoch stored in the checkpoint.
    """
    # Bare GPU ids keep the historical "cuda:<id>" mapping; explicit devices
    # ("cpu", "cuda:1", torch.device(...)) pass through unchanged.
    if isinstance(location, torch.device) or (isinstance(location, str) and not location.isdigit()):
        map_location = location
    else:
        map_location = f'cuda:{location}'
    state = torch.load(model_path, map_location=map_location)

    try:
        model.load_state_dict(state['model'], strict=True)
    except RuntimeError as outer_err:
        print(outer_err)
        try:
            model.module.load_state_dict(state['model'], strict=True)
        except (AttributeError, RuntimeError) as inner_err:
            # AttributeError: ``model`` has no ``.module`` (it is not wrapped).
            print(inner_err)
            # Loading through the wrapper fills ``model`` in place, because
            # the wrapper only holds a reference to it.
            WrappedModel(model).load_state_dict(state['model'], strict=True)

    optimizer.load_state_dict(state['optimizer'])
    if scheduler is not None:
        scheduler.load_state_dict(state['scheduler'])
    return state['epoch']


def restore_cos_ckpt(model, optimizer, scheduler, model_path, location):
    """Restore a checkpoint that carries a ``'cos_scheduler'`` entry.

    The ``'cos_scheduler'`` entry of the checkpoint is loaded directly into
    ``scheduler``.

    Args:
        model: Module to load weights into (strict key matching).
        optimizer: Optimizer whose state is restored.
        scheduler: Scheduler that receives ``state['cos_scheduler']``, or
            ``None`` to skip scheduler restoration.
        model_path: Checkpoint file with the keys ``'model'``, ``'optimizer'``,
            ``'cos_scheduler'`` and ``'epoch'``.
        location: GPU index (int or digit string, mapped to ``cuda:<index>``),
            or a ``torch.device`` / device string such as ``'cpu'`` that is
            used as-is.

    Returns:
        The epoch stored in the checkpoint.
    """
    # Bare GPU ids keep the historical "cuda:<id>" mapping; explicit devices
    # ("cpu", "cuda:1", torch.device(...)) pass through unchanged.
    if isinstance(location, torch.device) or (isinstance(location, str) and not location.isdigit()):
        map_location = location
    else:
        map_location = f'cuda:{location}'
    state = torch.load(model_path, map_location=map_location)

    model.load_state_dict(state['model'], strict=True)
    optimizer.load_state_dict(state['optimizer'])
    if scheduler is not None:
        scheduler.load_state_dict(state['cos_scheduler'])
        # scheduler.after_scheduler.load_state_dict(state['cos_scheduler'])
    return state['epoch']


class WrappedModel(nn.Module):
    """Thin container that registers a model under the attribute ``module``.

    Because the wrapped model is a submodule named ``module``, the wrapper's
    state-dict keys are the wrapped model's keys prefixed with ``module.``.
    The loaders in this file use it to read checkpoints saved from a
    data-parallel wrapped model.
    """

    def __init__(self, module):
        super().__init__()
        # The wrapped network; it holds all the actual parameters.
        self.module = module


def load_model(model, model_path, location, strict=True):
    """Load the ``'model'`` entry of a checkpoint into ``model`` and return it.

    The weights are loaded with three attempts, in order:

    1. directly into ``model``;
    2. into ``model.module`` (``model`` is wrapped, the checkpoint is not);
    3. through a temporary ``WrappedModel`` around ``model`` when ``model``
       has no ``module`` attribute (the checkpoint keys carry a ``module.``
       prefix).

    Args:
        model: Module to load weights into (modified in place).
        model_path: Checkpoint file containing a ``'model'`` state dict.
        location: GPU index (int or digit string, mapped to ``cuda:<index>``),
            or a ``torch.device`` / device string such as ``'cpu'`` that is
            used as-is.
        strict: Forwarded to ``load_state_dict``.

    Returns:
        The model that was passed in, with the checkpoint weights loaded.
    """
    # Bare GPU ids keep the historical "cuda:<id>" mapping; explicit devices
    # ("cpu", "cuda:1", torch.device(...)) pass through unchanged.
    if isinstance(location, torch.device) or (isinstance(location, str) and not location.isdigit()):
        map_location = location
    else:
        map_location = f'cuda:{location}'
    state = torch.load(model_path, map_location=map_location)

    try:
        model.load_state_dict(state['model'], strict=strict)
        return model
    except RuntimeError as outer_err:
        print(outer_err)

    try:
        model.module.load_state_dict(state['model'], strict=strict)
        return model
    except AttributeError as inner_err:
        # ``model`` is not wrapped, so it has no ``.module`` attribute.
        print(inner_err)

    wrapper = WrappedModel(model)  # for loading data parallel wrapped model
    wrapper.load_state_dict(state['model'], strict=strict)
    return wrapper.module


def save(model, model_path):
    """Write only the model's state dict (no optimizer or epoch) to ``model_path``."""
    weights = model.state_dict()
    torch.save(weights, model_path)


def load(model, model_path, location):
    """Load a bare state dict from ``model_path`` into ``model`` (strict).

    Args:
        model: Module to load weights into (modified in place).
        model_path: File holding a state dict as the top-level object.
        location: GPU index (int or digit string, mapped to ``cuda:<index>``),
            or a ``torch.device`` / device string such as ``'cpu'`` that is
            used as-is.
    """
    # Bare GPU ids keep the historical "cuda:<id>" mapping; explicit devices
    # ("cpu", "cuda:1", torch.device(...)) pass through unchanged.
    if isinstance(location, torch.device) or (isinstance(location, str) and not location.isdigit()):
        map_location = location
    else:
        map_location = f'cuda:{location}'
    model.load_state_dict(torch.load(model_path, map_location=map_location), strict=True)


class WarmUpLR(_LRScheduler):
    """Linear warm-up learning-rate scheduler.

    Args:
        optimizer: wrapped optimizer (e.g. SGD)
        total_iters: number of iterations in the warm-up phase
        last_epoch: index of the last step, forwarded to ``_LRScheduler``
    """

    def __init__(self, optimizer, total_iters, last_epoch=-1):
        # Stored before delegating to the base constructor so that the value
        # is already available if the base class calls ``get_lr``.
        self.total_iters = total_iters
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return ``base_lr * last_epoch / total_iters`` for every param group.

        A small epsilon in the denominator avoids division by zero when
        ``total_iters`` is 0.
        """
        warmup_lrs = []
        for base_lr in self.base_lrs:
            warmup_lrs.append(base_lr * self.last_epoch / (self.total_iters + 1e-8))
        return warmup_lrs

# def load_augmentor(args, augmentor, after_transforms, n_class, sub_policies, aug_model_name):
#     # aug_n_class = n_class if args.aug_n_class == 0 else args.args_n_class
#     # ugly way to get the model name and get the number of class the augmentor predicts
#     # args: dataset, use_cuda, use_parallel, temperature, aug_mode
#     # policy_path, k_ops, save, gpu
#     resize = True if aug_model_name == 'resnet50' else False

#     if args.aug_mode == 'vector':
#         ops, ms, ws = parse_genotype(args.policy_path)
#         augmentor.add_augment_agent(ops, after_transforms, aug_mode='vector',
#                                     search=False, ops_weights=ws, magnitudes=ms, resize=resize)
#         for param in augmentor.parameters():
#             param.requires_grad = False
#     elif args.aug_mode == 'projection' or args.aug_mode == 'cnn':
#         ops = sub_policies
#         augmentor.add_augment_agent(ops, after_transforms, aug_mode=args.aug_mode,
#                                     search=False, delta=args.delta, k_ops=args.k_ops, sampling='prob',
#                                     temperature=args.temperature, save_dir=args.save,
#                                     infer_n_class=n_class, resize=resize, n_proj_layer=args.n_proj_layer)
#         load(augmentor, f'{args.policy_path}/weights.pt', location=args.gpu)
#         for param in augmentor.parameters():
#             param.requires_grad = False
#     elif args.aug_mode == 'randaug':
#         augmentor = getRandAugmentAgent(args.dataset, after_transforms)
#         # logging.info(f'RandAugment with n={augmentor.n} m={augmentor.m}')

#     elif args.aug_mode in ['dada', 'aa', 'fa']:
#         augmentor = getPolicyAugmentAgent(args.dataset, args.aug_mode, after_transforms)
#     elif args.aug_mode == 'pba':
#         augmentor = getPBAAugmentAgent(args.dataset, args.epochs, after_transforms)
#     else:
#         raise ValueError('invalid augmentor name=%s' % args.aug_mode)

#     return augmentor


# class PolicyHistory(object):

#     def __init__(self, op_names, save_dir, n_class):
#         self.op_names = op_names
#         self.save_dir = save_dir
#         self._initialize(n_class)

#     def _initialize(self, n_class):
#         self.history = []
#         # [{m:[], w:[]}, {}]
#         for i in range(n_class):
#             self.history.append({'magnitudes': [],
#                                 'weights': [],
#                                 'var_magnitudes': [],
#                                 'var_weights': []})

#     def add(self, class_idx, m_mu, w_mu, m_std, w_std):
#         if not isinstance(m_mu, list):  # ugly way to bypass batch with single element
#             return
#         self.history[class_idx]['magnitudes'].append(m_mu)
#         self.history[class_idx]['weights'].append(w_mu)
#         self.history[class_idx]['var_magnitudes'].append(m_std)
#         self.history[class_idx]['var_weights'].append(w_std)

#     def save(self, class2label=None):
#         path = os.path.join(self.save_dir, 'policy')
#         vis_path = os.path.join(self.save_dir, 'vis_policy')
#         os.makedirs(path, exist_ok=True)
#         os.makedirs(vis_path, exist_ok=True)
#         header = ','.join(self.op_names)
#         for i, history in enumerate(self.history):
#             k = i if class2label is None else class2label[i]
#             np.savetxt(f'{path}/policy{i}({k})_magnitude.csv',
#                        history['magnitudes'], delimiter=',', header=header, comments='')
#             np.savetxt(f'{path}/policy{i}({k})_weights.csv',
#                        history['weights'], delimiter=',', header=header, comments='')
#             # np.savetxt(f'{vis_path}/policy{i}({k})_var_magnitude.csv',
#             #            history['var_magnitudes'], delimiter=',', header=header, comments='')
#             # np.savetxt(f'{vis_path}/policy{i}({k})_var_weights.csv',
#             #            history['var_weights'], delimiter=',', header=header, comments='')

#     def plot(self):
#         PATH = self.save_dir
#         mag_file_list = glob.glob(f'{PATH}/policy/*_magnitude.csv')
#         weights_file_list = glob.glob(f'{PATH}/policy/*_weights.csv')
#         n_class = len(mag_file_list)

#         f, axes = plt.subplots(n_class, 2, figsize=(15, 5*n_class))

#         for i, file in enumerate(mag_file_list):
#             df = pd.read_csv(file).dropna()
#             x = range(0, len(df))
#             y = df.to_numpy().T
#             axes[i][0].stackplot(x, y, labels=df.columns, edgecolor='none')
#             axes[i][0].set_title(file.split('/')[-1][:-4])

#         for i, file in enumerate(weights_file_list):
#             df = pd.read_csv(file).dropna()
#             x = range(0, len(df))
#             y = df.to_numpy().T
#             axes[i][1].stackplot(x, y, labels=df.columns, edgecolor='none')
#             axes[i][1].set_title(file.split('/')[-1][:-4])

#         axes[-1][-1].legend(loc='upper center', bbox_to_anchor=(-0.1, -0.2), fancybox=True, shadow=True, ncol=10)
#         plt.savefig(f'{PATH}/policy/visual.png')
#         return f


class AvgrageMeter(object):
    """Tracks a running weighted average.

    Attributes are ``sum`` (weighted total), ``cnt`` (total weight) and
    ``avg`` (``sum / cnt`` after the latest update).
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.avg = 0
        self.sum = 0
        self.cnt = 0

    def update(self, val, n=1):
        """Add ``val`` with weight ``n`` and refresh the average."""
        weighted = val * n
        self.sum += weighted
        self.cnt += n
        self.avg = self.sum / self.cnt


class DataMixer():
    """Mixes each sample of a batch with a partner sample from the same batch.

    Two modes are supported, selected by ``mix_mode``:

    * if the string ``'camix'`` occurs in ``mix_mode``, the two samples are
      blended with spatial masks derived from ``cam_model.get_cam`` (and,
      optionally, refined by ``masknet``);
    * otherwise plain mixup is used: ``lam * x + (1 - lam) * x[index]``.
    """
    def __init__(self, mix_mode, cam_model=None, scale2target=True, masknet=None, use_hidden=False, stack_all=True):
        """Store the mixing configuration.

        Args:
            mix_mode: Mode string; only whether it contains ``'camix'`` matters here.
            cam_model: Object providing ``get_cam(img, y, mixing=..., [adv=...])``;
                required for the ``'camix'`` mode.
            scale2target: Selects ``mixing=False`` (True) or ``mixing=True`` (False)
                for ``get_cam`` and whether masks are unsqueezed on dim 1 before
                being multiplied with ``x``.
            masknet: Optional callable that predicts the mask; its presence
                enables the "learnable mask" path.
            use_hidden: Together with ``stack_all`` controls what is fed to ``masknet``.
            stack_all: See ``use_hidden``.
        """
        self.mix_mode = mix_mode
        self.cam_model = cam_model
        self.scale2target = scale2target
        self.learnable_mask = True if masknet is not None else False
        self.use_hidden = use_hidden
        self.masknet = masknet
        # Always initialised to False here; the adversarial branch below only
        # runs if a caller flips this attribute after construction.
        self.adv_noise = False
        self.stack_all = stack_all

    def mixup_data(self, x, y, alpha=1.0, use_cuda=True, img=None, layer_mix=4, return_mask=False, perm=None):
        """Mix ``x`` with a permuted copy of itself.

        Args:
            x: Batch to mix (first dimension is the batch dimension).
            y: Targets, indexable by the permutation.
            alpha: Beta-distribution parameter; ``lam ~ Beta(alpha, alpha)`` when
                ``alpha > 0``, otherwise ``lam = 1``.
            use_cuda: When true the random permutation is moved to CUDA.
            img: Input passed to ``cam_model.get_cam`` in ``'camix'`` mode.
            layer_mix: Not used inside this method.
            return_mask: When true a dictionary with intermediate tensors is
                returned instead of the 4-tuple.
            perm: Optional pair ``(lam, index)`` overriding the randomly drawn values.

        Returns:
            ``(mixed_x, y_a, y_b, y_lam)`` or, if ``return_mask`` is true, a dict
            with the keys ``mixed_x, x_a, x_b, y_a, y_b, y_lam, index,
            cam_mask_a, cam_mask_b, sm_two_mask_a, sm_two_mask_b`` (the mask
            entries stay ``None`` outside ``'camix'`` mode).
        """
        if alpha > 0:
            lam = np.random.beta(alpha, alpha)
        else:
            lam = 1

        batch_size = x.size()[0]

        if use_cuda:
            index = torch.randperm(batch_size).cuda()
        else:
            index = torch.randperm(batch_size)

        # A caller-supplied (lam, index) pair replaces the values drawn above.
        # Note the random draws still happen, so RNG state advances either way.
        if perm is not None:
            lam = perm[0]
            index = perm[1]

        if 'camix' in self.mix_mode:
            if self.scale2target:  # input cam
                cam_mask = self.cam_model.get_cam(img, y, mixing=False)  # scale to input size
            else:
                cam_mask = self.cam_model.get_cam(img, y, mixing=True)
            cam_mask = cam_mask.squeeze()

            # important mask of img1 add adversarial noise from img2
            if self.adv_noise:
                cam_mask1 = self.cam_model.get_cam(img, y, mixing=False)
                cam_mask2 = self.cam_model.get_cam(img, y, mixing=False, adv=True)
            else:
                cam_mask1 = cam_mask2 = cam_mask

            if self.learnable_mask:
                # Insert a channel dimension so the masks can be concatenated on dim 1.
                cam_mask1 = cam_mask1.unsqueeze(1)
                cam_mask2 = cam_mask2.unsqueeze(1)
                # M1 linear scaling lam
                # two_mask = torch.stack([lam*cam_mask1, (1-lam)*cam_mask2[index, :]], dim=0).cuda()

                # M2 pass lam to network
                # lam is broadcast to a tensor with the same shape as the mask.
                lam_tensor = torch.tensor(lam).repeat(cam_mask1.shape)
                if self.use_hidden and self.stack_all:
                    # Also feed both inputs (x and its permuted partner) to the mask network.
                    feature_set = [lam_tensor.cuda(), cam_mask1.cuda(), cam_mask2[index, :].cuda(), x, x[index, :]]
                else:
                    feature_set = [lam_tensor.cuda(), cam_mask1.cuda(), cam_mask2[index, :].cuda()]

                two_mask = torch.cat(feature_set, dim=1)  # two_mask: torch.Size([32, C, 4, 4])
                # two_mask = two_mask.permute(1,0,2,3)  # torch.Size([32, 2 or 3, 4, 4])
                if self.use_hidden and not self.stack_all:
                    predicted_mask = self.masknet(x, x[index, :], two_mask).squeeze()  # for M3 setting
                else:
                    predicted_mask = self.masknet(two_mask).squeeze()  # torch.Size([32, 4, 4])
                # The partner sample receives the complement of the predicted mask.
                sm_two_mask = torch.stack([predicted_mask, 1-predicted_mask], dim=0)  # torch.Size([2, 32, 4, 4])

            else:
                # Method A
                # A1
                two_mask = torch.stack([lam*cam_mask1, (1-lam)*cam_mask2[index, :]], dim=0)
                # A2
                # two_mask = torch.stack([lam*cam_mask, lam*cam_mask[index, :]], dim=0)
                # Softmax over the stacking dimension makes the two masks sum to 1 element-wise.
                sm_two_mask = F.softmax(two_mask, dim=0).cuda()
                # A3
                # sm_two_mask = torch.stack([lam*cam_mask, (1-lam)*cam_mask[index, :]], dim=0).cuda()
            if self.scale2target:  # unsqueeze to broadcast cam to RGB images
                mixed_x = torch.mul(sm_two_mask[0].unsqueeze(1), x) + \
                    torch.mul(sm_two_mask[1].unsqueeze(1), x[index, :])
            else:
                mixed_x = torch.mul(sm_two_mask[0], x) + torch.mul(sm_two_mask[1], x[index, :])

            # Method B
            # eps = torch.tensor(1e-5)
            # maska = lam*cam_mask
            # maskb = (1-lam)*cam_mask[index, :]
            # mask_sum = torch.add(maska, maskb)
            # maska = torch.div(maska, mask_sum+eps).cuda()
            # maskb = torch.div(maskb, mask_sum+eps).cuda()
            # mixed_x = torch.mul(maska, x) + torch.mul(maskb, x[index, :])

            # Method C
            # two_mask = torch.stack([cam_mask, cam_mask[index, :]], dim=0)
            # sm_two_mask = F.softmax(two_mask, dim=0).cuda()
            # mixed_x = torch.mul(sm_two_mask[0], lam*x) + torch.mul(sm_two_mask[1], (1-lam)*x[index, :])

            if self.adv_noise:
                y_a = y_b = y
            else:
                y_a, y_b = y, y[index]

            # adj soft label
            # if self.scale2target:
            #     y_lam = torch.sum(sm_two_mask[0], dim=(1,2))/torch.sum(sm_two_mask, dim=(0,2,3))
            # else:
            #     y_lam = torch.sum(sm_two_mask[0], dim=(1,2,3))/torch.sum(sm_two_mask, dim=(0,2,3,4))

            # The label weight is the sampled lam, not the mask area.
            y_lam = lam

        else:
            # Plain mixup: convex combination of each sample and its partner.
            mixed_x = lam * x + (1 - lam) * x[index, :]
            y_a, y_b = y, y[index]
            y_lam = lam

        if return_mask:
            meta = {'mixed_x': mixed_x,
                    'x_a': x,
                    'x_b': x[index, :],
                    'y_a': y_a,
                    'y_b': y_b,
                    'y_lam': y_lam,
                    'index': index,
                    'cam_mask_a': None,
                    'cam_mask_b': None,
                    'sm_two_mask_a': None,
                    'sm_two_mask_b': None}

            if 'camix' in self.mix_mode:
                # In the learnable-mask path cam_mask1/cam_mask2 were unsqueezed
                # on dim 1 above, so they are reported in that form.
                meta['cam_mask_a'] = cam_mask1
                meta['cam_mask_b'] = cam_mask2[index, :]
                meta['sm_two_mask_a'] = sm_two_mask[0]
                meta['sm_two_mask_b'] = sm_two_mask[1]
            return meta
        else:
            return mixed_x, y_a, y_b, y_lam


def mixup_data(x, y, alpha=1.0, use_cuda=True):
    """Return mixup inputs, both target sets and the mixing coefficient.

    Args:
        x: Batch of inputs with at least two dimensions.
        y: Targets, indexable by a permutation tensor.
        alpha: ``lam ~ Beta(alpha, alpha)`` when ``alpha > 0``; else ``lam = 1``.
        use_cuda: When true the permutation is moved to the current CUDA device.

    Returns:
        ``(mixed_x, y_a, y_b, lam)`` with ``mixed_x = lam * x + (1 - lam) * x[perm]``,
        ``y_a = y`` and ``y_b = y[perm]``.
    """
    lam = np.random.beta(alpha, alpha) if alpha > 0 else 1

    index = torch.randperm(x.size()[0])
    if use_cuda:
        index = index.cuda()

    partner_x = x[index, :]
    partner_y = y[index]
    mixed_x = lam * x + (1 - lam) * partner_x
    return mixed_x, y, partner_y, lam


def mix_accuracy(output, target_a, target_b, lam, topk=(1,)):
    """Top-k accuracy for mixed samples, weighted between two target sets.

    For every ``k`` the result is
    ``lam * acc_k(output, target_a) + (1 - lam) * acc_k(output, target_b)``
    expressed in percent.

    Args:
        output: Score tensor of shape ``(batch, classes)``.
        target_a: Class indices of the first target set, one per sample.
        target_b: Class indices of the second target set, one per sample.
        lam: Mixing weight; a Python/NumPy number or a tensor (tensors are
            reduced with ``mean()``).
        topk: Iterable of ``k`` values.

    Returns:
        List with one scalar tensor per entry of ``topk``.
    """
    maxk = max(topk)
    batch_size = target_a.size(0)

    # Plain numbers (including ints such as the ``lam = 1`` produced when
    # mixup is disabled, and NumPy floats) are used directly; anything else is
    # assumed to be tensor-like and averaged. The old ``type(lam) != float``
    # test sent ints to ``.mean()`` and crashed.
    if isinstance(lam, (int, float)):
        lam = float(lam)
    else:
        lam = lam.mean()

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct_a = pred.eq(target_a.view(1, -1).expand_as(pred))
    correct_b = pred.eq(target_b.view(1, -1).expand_as(pred))

    res = []
    for k in topk:
        hits_a = correct_a[:k].reshape(-1).float().sum(0)
        hits_b = correct_b[:k].reshape(-1).float().sum(0)
        correct_k = lam * hits_a + (1-lam) * hits_b
        res.append(correct_k.mul_(100.0/batch_size))

    return res


def mix_k_accuracy(output, target, index_list, lam_list, topk=(1,)):
    """Top-k accuracy for a k-way mix, weighted per mixed source.

    Each entry of ``index_list`` selects the targets of one mixed source via
    ``target[index]``; its hit count is weighted by the matching entry of
    ``lam_list``. Results are in percent of the batch size.

    Args:
        output: Score tensor of shape ``(batch, classes)``.
        target: Class indices, indexable by every entry of ``index_list``.
        index_list: Index tensors, one per mixed source.
        lam_list: Weights, one per entry of ``index_list``.
        topk: Iterable of ``k`` values.

    Returns:
        List with one scalar tensor per entry of ``topk``.
    """
    maxk = max(topk)
    batch_size = output.size(0)

    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct_list = [pred.eq(target[index].view(1, -1).expand_as(pred)) for index in index_list]

    res = []
    for k in topk:
        # sum() starts from 0 and adds the weighted hit counts in list order.
        correct_k = sum(lam * correct[:k].reshape(-1).float().sum(0)
                        for correct, lam in zip(correct_list, lam_list))
        res.append(correct_k.mul_(100.0/batch_size))
    return res


def mix_soft_accuracy(output, target):
    """Soft-label accuracy: target mass captured by the top-2 predictions.

    A 0/1 indicator marks the two highest-scoring classes of every sample;
    the indicator is multiplied element-wise with ``target`` and the total is
    scaled to percent of the batch size.

    Args:
        output: Score tensor of shape ``(batch, classes)`` with at least two classes.
        target: Soft-label tensor broadcastable against ``output``.

    Returns:
        Scalar tensor holding the accuracy in percent.
    """
    n = output.size(0)
    _, pred = output.topk(2, 1, True, True)
    # Build the indicator on the target's device instead of forcing CUDA, so
    # the metric also works on CPU; results on CUDA targets are unchanged.
    top2_indicator = torch.zeros(output.size(), device=target.device)
    # Vectorised replacement of the per-row Python loop: set 1 at the two
    # predicted class indices of every row.
    top2_indicator.scatter_(1, pred.to(target.device), 1.0)
    acc = torch.sum(top2_indicator * target) * (100/n)
    return acc


def accuracy(output, target, topk=(1,)):
    """Compute top-k accuracy in percent for every ``k`` in ``topk``.

    Args:
        output: Score tensor of shape ``(batch, classes)``.
        target: Class indices, one per sample.
        topk: Iterable of ``k`` values.

    Returns:
        List with one scalar tensor per entry of ``topk``.
    """
    batch_size = target.size(0)
    percent_per_hit = 100.0/batch_size

    # Rows of ``ranked`` are prediction ranks, columns are samples.
    _, top_indices = output.topk(max(topk), 1, True, True)
    ranked = top_indices.t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    return [hits[:k].reshape(-1).float().sum(0).mul_(percent_per_hit) for k in topk]


class Cutout(object):
    """Transform that zeroes a square patch centred at a random pixel.

    The patch has side ``length`` before it is clipped to the image borders.
    The input tensor is modified in place and returned.
    """

    def __init__(self, length):
        self.length = length

    def __call__(self, img):
        h, w = img.size(1), img.size(2)

        # Draw the centre row first, then the centre column.
        center_y = np.random.randint(h)
        center_x = np.random.randint(w)
        half = self.length // 2

        top = np.clip(center_y - half, 0, h)
        bottom = np.clip(center_y + half, 0, h)
        left = np.clip(center_x - half, 0, w)
        right = np.clip(center_x + half, 0, w)

        keep = np.ones((h, w), np.float32)
        keep[top: bottom, left: right] = 0.

        # Broadcast the 2-D mask over the leading (channel) dimension.
        img *= torch.from_numpy(keep).expand_as(img)
        return img


def _data_transforms_cifar10(args):
    """Build the CIFAR-10 training and validation transform pipelines.

    Training: random 32x32 crop with padding 4, random horizontal flip,
    conversion to tensor and, when ``args.cutout`` is truthy, ``Cutout`` with
    ``args.cutout_length``. Normalization is not applied to training images
    here (the Normalize step is disabled).

    Validation: conversion to tensor followed by normalization with the
    CIFAR mean and standard deviation.

    Returns:
        ``(train_transform, valid_transform)``.
    """
    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

    train_steps = [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        # transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ]
    if args.cutout:
        train_steps.append(Cutout(args.cutout_length))
    train_transform = transforms.Compose(train_steps)

    valid_steps = [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ]
    valid_transform = transforms.Compose(valid_steps)

    return train_transform, valid_transform


def count_parameters_in_MB(model):
    """Return the number of model parameters in millions.

    Parameters whose name contains ``"auxiliary"`` are excluded.
    """
    # ``np.sum`` over a generator is deprecated in NumPy; sum a concrete list
    # of per-parameter element counts instead.
    counts = [v.numel() for name, v in model.named_parameters() if "auxiliary" not in name]
    return np.sum(counts)/1e6


def drop_path(x, drop_prob):
    """Randomly zero whole samples of a 4-D batch in place (drop-path).

    Each sample is kept with probability ``1 - drop_prob``; kept samples are
    rescaled by ``1 / (1 - drop_prob)``. ``x`` is modified in place and
    returned. With ``drop_prob <= 0`` the input is returned untouched.

    NOTE(review): a second ``drop_path`` with a different signature is defined
    further down in this module and rebinds the name at import time, so
    module-level callers reach that later definition, not this one.

    Args:
        x: Tensor whose first dimension is the batch dimension and which
            broadcasts against a ``(batch, 1, 1, 1)`` mask.
        drop_prob: Probability of dropping a sample.
    """
    if drop_prob > 0.:
        keep_prob = 1.-drop_prob
        # Allocate the mask on x's device rather than hard-coding CUDA, so the
        # function also works on CPU tensors; float32 matches the previous
        # ``torch.cuda.FloatTensor`` mask dtype.
        mask = torch.empty(x.size(0), 1, 1, 1, dtype=torch.float32, device=x.device).bernoulli_(keep_prob)
        x.div_(keep_prob)
        x.mul_(mask)
    return x


def create_exp_dir(path, scripts_to_save=None):
    """Create an experiment directory and optionally snapshot scripts into it.

    Args:
        path: Directory to create; missing parent directories are created too
            and an already existing directory is reused.
        scripts_to_save: Optional iterable of file paths that are copied into
            ``<path>/scripts`` (existing copies are overwritten).
    """
    # makedirs(exist_ok=True) replaces the exists-check + mkdir pair and also
    # handles nested paths.
    os.makedirs(path, exist_ok=True)
    print('Experiment dir : {}'.format(path))

    if scripts_to_save is not None:
        scripts_dir = os.path.join(path, 'scripts')
        # Tolerate re-running into the same experiment directory.
        os.makedirs(scripts_dir, exist_ok=True)
        for script in scripts_to_save:
            dst_file = os.path.join(scripts_dir, os.path.basename(script))
            shutil.copyfile(script, dst_file)


def reproducibility(seed):
    """Seed every random number generator used here and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy and PyTorch (CPU and all CUDA devices),
    forces deterministic cuDNN kernels and disables the cuDNN auto-tuner.

    Args:
        seed: Integer seed shared by all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    # The auto-tuner may pick different convolution algorithms from run to
    # run, which defeats the purpose of this function, so it must be off.
    torch.backends.cudnn.benchmark = False
    # NOTE: anomaly detection is a debugging aid that adds overhead to every
    # backward pass; it is kept enabled to preserve existing behaviour.
    torch.autograd.set_detect_anomaly(True)


import torch
import torch.nn as nn

def drop_path(x, drop_prob=0., training=False):
    """Stochastic depth: drop whole samples of a batch (for the main path of
    residual blocks).

    Returns ``x`` itself when ``drop_prob`` is 0 or ``training`` is false.
    Otherwise each sample is kept with probability ``1 - drop_prob`` and the
    result is ``x / keep_prob`` multiplied by a per-sample 0/1 mask; ``x`` is
    not modified.

    NOTE(review): this definition replaces an earlier two-argument
    ``drop_path(x, drop_prob)`` in this module. Two-argument calls therefore
    land here and, because ``training`` defaults to ``False``, return ``x``
    unchanged.
    """
    if not training or drop_prob == 0.:
        return x

    keep_prob = 1 - drop_prob
    # One mask value per sample, broadcast over all remaining dimensions, so
    # tensors of any rank are supported (not just 4-D).
    mask_shape = (x.shape[0], ) + (1, ) * (x.ndim - 1)
    uniform = torch.rand(mask_shape, dtype=x.dtype, device=x.device)
    # floor(keep_prob + U[0, 1)) equals 1 with probability keep_prob, else 0.
    keep_mask = (keep_prob + uniform).floor()
    return x.div(keep_prob) * keep_mask


class DropPath(nn.Module):
    """Module wrapper around ``drop_path`` (stochastic depth per sample, for
    the main path of residual blocks).

    Args:
        drop_prob (float): Probability of the path to be zeroed. Default: 0.1
    """

    def __init__(self, drop_prob=0.1):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # ``self.training`` is the module's train/eval flag; it is passed on
        # as the ``training`` argument of ``drop_path``.
        is_training = self.training
        return drop_path(x, self.drop_prob, is_training)
