from imagenet_util import ImageNetDataset
import pixmix_utils as utils
import torch
import numpy as np
import torchvision.transforms as transforms
import torchvision.datasets as datasets

def pixmix(orig, mixing_pic, k=4, beta=4):
  """Blend ``orig`` with ``mixing_pic`` over a random number of rounds.

  The number of rounds is drawn from ``0..k`` (inclusive) with
  ``np.random.randint(k + 1)``; when zero rounds are drawn, ``orig`` is
  returned unchanged. Rounds alternate between ``utils.mixings[0]`` (even
  steps) and ``utils.mixings[1]`` (odd steps). Each op is called as
  ``op(current, mixing_pic, beta)``, and its output is clipped to ``[0, 1]``
  with ``torch.clip``.

  Args:
    orig: image to augment (presumably a float tensor in [0, 1] — TODO
      confirm against callers).
    mixing_pic: image blended into ``orig`` on every round.
    k: maximum number of mixing rounds.
    beta: forwarded unchanged to every mixing op.

  Returns:
    The mixed image, or ``orig`` itself when zero rounds are drawn.
  """
  ops = utils.mixings
  n_rounds = np.random.randint(k + 1)
  result = orig
  for step in range(n_rounds):
    # Deterministic alternation between the first two mixing ops.
    op = ops[step % 2]
    result = torch.clip(op(result, mixing_pic, beta), 0, 1)
  return result

class PixMixDataset(torch.utils.data.Dataset):
  """Dataset wrapper to perform PixMix.

  Each sample is read from an ``ImageNetDataset``. It is blended via
  ``pixmix`` with a picture chosen uniformly at random from a separate
  ``ImageFolder`` mixing set.

  Args:
    intenstity: scales the maximum number of mixing rounds
      (``k = intenstity * 2``). The misspelled name is kept so existing
      keyword callers keep working.
    data_root: path handed to ``ImageNetDataset``. Defaults to ``''``, the
      value that was previously hard-coded.
    mixing_root: root directory of the mixing-set ``ImageFolder``. Defaults
      to ``''``, the value that was previously hard-coded.
    beta: blending parameter forwarded to ``pixmix``. Defaults to 10, the
      value that was previously hard-coded.
  """

  def __init__(self, intenstity=1, data_root='', mixing_root='', beta=10):
    self.dataset = ImageNetDataset(
        data_root,
    )
    self.k, self.beta = intenstity * 2, beta
    # Mixing pictures: shorter side resized to 256, randomly cropped to
    # 224x224, then converted to a tensor.
    self.mixing_set = datasets.ImageFolder(
        mixing_root,
        transform=transforms.Compose([
            transforms.Resize(256),
            transforms.RandomCrop(224),
            transforms.ToTensor(),
        ]),
    )

  def __getitem__(self, i):
    """Return ``(pixmix(x, mixing_pic, k, beta), y)`` for base sample ``i``.

    A new mixing picture is drawn on every access. Its label is discarded.
    """
    x, y = self.dataset[i]
    # NOTE(review): draws from the global NumPy RNG. If this runs inside
    # DataLoader workers, confirm that each worker's NumPy RNG is seeded
    # distinctly for your PyTorch version.
    rnd_idx = np.random.choice(len(self.mixing_set))
    mixing_pic, _ = self.mixing_set[rnd_idx]
    return pixmix(x, mixing_pic, self.k, self.beta), y

  def __len__(self):
    """Number of samples, equal to the length of the wrapped base dataset."""
    return len(self.dataset)