import numpy as np
import torch
import torchvision.transforms as transforms
import utils.augmix_ops as augmentations

# AugMix Transforms
def get_preaugment():
    """Build the geometric pre-augmentation applied before AugMix mixing.

    Returns:
        A ``transforms.Compose`` that takes a random resized crop with output
        size 224 (torchvision's default scale/ratio ranges, since none are
        passed) followed by a random horizontal flip.
    """
    steps = [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
    ]
    return transforms.Compose(steps)

def augmix(image, preprocess, aug_list, severity=1, width=3, alpha=1.0):
    """Return one AugMix view of ``image``.

    A random resized crop + horizontal flip (``get_preaugment``) is applied
    first. If ``aug_list`` is empty, the preprocessed crop is returned as-is.
    Otherwise ``width`` augmentation chains are built from the crop. Each chain
    applies 1-3 operations drawn at random from ``aug_list``. The preprocessed
    chains are summed with Dirichlet(``alpha``) weights, and the result is
    blended with the preprocessed crop as ``m * crop + (1 - m) * mix``, where
    ``m ~ Beta(alpha, alpha)``.

    Args:
        image: Image tensor. A 3-D tensor is treated as a single image. Any
            higher-rank tensor is treated as a batch along dim 0, and every
            sample gets its own augmentation chains. Values are presumably in
            [0, 1], since chains round-trip through ``ToPILImage``/``ToTensor``
            (NOTE(review): confirm against callers).
        preprocess: Callable applied both to the crop and to every augmented
            chain. It must return tensors of the same shape for both.
        aug_list: Sequence of callables ``op(pil_image, severity) -> pil_image``.
        severity: Passed unchanged to every augmentation op.
        width: Number of augmentation chains to mix (default 3).
        alpha: Concentration for the Dirichlet chain weights and the Beta blend
            coefficient (default 1.0).

    Returns:
        Tensor shaped like ``preprocess(crop)``.
    """
    preaugment = get_preaugment()   # random resized crop + horizontal flip
    x_orig = preaugment(image)
    x_processed = preprocess(x_orig)
    if len(aug_list) == 0:
        return x_processed
    w = np.float32(np.random.dirichlet([alpha] * width))
    m = np.float32(np.random.beta(alpha, alpha))

    to_pil = transforms.ToPILImage()
    to_tensor = transforms.ToTensor()

    def _augment_one(img):
        # Apply a random-depth (1-3) chain of ops to a single (C, H, W) tensor.
        pil = to_pil(img)
        for _ in range(np.random.randint(1, 4)):
            pil = np.random.choice(aug_list)(pil, severity)
        return to_tensor(pil)

    mix = torch.zeros_like(x_processed)
    for i in range(width):
        if x_orig.dim() == 3:
            x_aug = _augment_one(x_orig)
        else:
            # Augment every sample independently. Using only x_orig[0] would
            # broadcast one image's mixture onto the whole batch.
            x_aug = torch.stack([_augment_one(img) for img in x_orig])
        mix += w[i] * preprocess(x_aug.to(x_processed.device))
    return m * x_processed + (1 - m) * mix

class AugMixAugmenter:
    """Callable that expands one input into a list of views.

    The first view is ``preprocess(base_transform(x))``, or ``preprocess(x)``
    when ``base_transform`` is None. It is followed by ``n_views`` views, each
    produced by an independent ``augmix`` call on the raw ``x``.
    """

    def __init__(self, base_transform, preprocess, n_views=2, augmix=False,
                 severity=1):
        """Store the transforms and choose the augmentation op list.

        Args:
            base_transform: Optional transform applied before ``preprocess``
                for the first (un-augmented) view. ``None`` skips it.
            preprocess: Callable applied to every view.
            n_views: Number of augmented views appended after the first view.
            augmix: If truthy, ops come from ``augmentations.augmentations``.
                Otherwise the op list is empty, and ``augmix`` returns just the
                preprocessed crop for each extra view.
            severity: Forwarded to every augmentation op.
        """
        self.base_transform = base_transform
        self.preprocess = preprocess
        self.n_views = n_views
        self.severity = severity
        # Inside this method ``augmix`` is the boolean flag. It shadows the
        # module-level ``augmix`` function only within this scope.
        self.aug_list = augmentations.augmentations if augmix else []

    def __call__(self, x):
        """Return ``[clean_view, *augmented_views]`` for input ``x``."""
        source = x if self.base_transform is None else self.base_transform(x)
        result = [self.preprocess(source)]
        for _ in range(self.n_views):
            # Augmented views start from the raw input, not the
            # base-transformed one; ``augmix`` does its own crop/flip.
            result.append(
                augmix(x, self.preprocess, self.aug_list, self.severity))
        return result