
import os
import numpy as np
from PIL import Image


import torch
from torchvision import datasets

####################################################################################

def color_grayscale_arr(arr, red=True):
  """Place a 2-D grayscale image into the red or green channel of an RGB array.

  Args:
      arr (np.ndarray): Image of shape (H, W).
      red (bool): If truthy, ``arr`` becomes channel 0 (red); otherwise it
          becomes channel 1 (green). The remaining channels are zero.

  Returns:
      np.ndarray: Array of shape (H, W, 3) with the same dtype as ``arr``.
  """
  assert arr.ndim == 2
  rows, cols = arr.shape
  plane = arr[:, :, np.newaxis]
  empty = np.zeros((rows, cols, 1), dtype=arr.dtype)
  channels = [plane, empty, empty] if red else [empty, plane, empty]
  return np.concatenate(channels, axis=2)


class ColoredMNIST(datasets.VisionDataset):
  """
  Colored MNIST dataset for testing OOD generalization. Prepared using a similar procedure to https://arxiv.org/pdf/1907.02893.pdf
  We flip the color with a probability that depends on the index.

  The three splits are built once from the MNIST training set and cached under
  ``<root>/ColoredMNIST/{train,detect,test}.pt``:
    * train  -- ``train_len`` evenly spaced indices from the first 40000 images.
    * detect -- ``detection_len`` evenly spaced indices from the first 40000
      images that were not picked for train.
    * test   -- every image with index >= 40000.

  Each image gets a binary label (digit < 5 -> 0, else 1), which is flipped
  with probability 0.25. The image is colored red for label 0 and green for
  label 1. That color is then flipped with probability 0.1 for indices
  < 20000, 0.4 for indices in [20000, 40000), and 0.9 for test images.

  NOTE: the cache is reused whenever all three files exist, so ``train_len``
  and ``detection_len`` only take effect when the cache is first built. No
  random seed is set, so a freshly built cache differs between runs.
  """

  # Valid values for ``env``; also the cached split file stems, in save order.
  _ENVS = ('train', 'detect', 'test')

  def __init__(self, root='./data', env='train', transform=None, target_transform=None, train_len=2000, detection_len=500):
    super().__init__(root, transform=transform, target_transform=target_transform)

    # Validate first, so that a typo fails fast. Otherwise a bad env would
    # trigger an MNIST download and a full dataset build before failing.
    if env not in self._ENVS:
      raise RuntimeError(f'{env} env unknown. Valid envs are train, detect, and test.')
    self.env = env

    self.prepare_colored_mnist(train_len, detection_len)
    self.data_label_tuples = self._load_split(os.path.join(self.root, 'ColoredMNIST', env) + '.pt')

  def __getitem__(self, index):
    """
    Args:
        index (int): Index

    Returns:
        tuple: (image, target) for the train and test envs, or
        (image, target, index) for the detect env, where target is the
        (possibly flipped) binary label, 0 or 1.
    """
    img, target = self.data_label_tuples[index]

    if self.transform is not None:
      img = self.transform(img)

    if self.target_transform is not None:
      target = self.target_transform(target)

    # The detection set also returns the index so callers can map results
    # back to individual samples.
    if self.env == 'detect':
      return img, target, index
    return img, target

  def __len__(self):
    return len(self.data_label_tuples)

  @staticmethod
  def _load_split(path):
    """Load one cached split (a list of ``(PIL.Image, int)`` tuples) from *path*."""
    import inspect
    # The cache stores pickled PIL images. torch.load rejects these under the
    # weights_only=True default introduced in torch 2.6, so full unpickling is
    # requested explicitly. Torch < 1.13 has no weights_only argument at all.
    # NOTE(review): this unpickles arbitrary objects. Only point ``root`` at a
    # cache this class built itself, never at untrusted files.
    if 'weights_only' in inspect.signature(torch.load).parameters:
      return torch.load(path, weights_only=False)
    return torch.load(path)

  def prepare_colored_mnist(self, train_len, detection_len):
    """Build and cache the train/detect/test splits unless all three already exist.

    Args:
        train_len (int): Number of evenly spaced training indices drawn from
            the first 40000 MNIST training images.
        detection_len (int): Number of evenly spaced detection indices drawn
            from the first 40000 images not already used for training.

    Side effects:
        Downloads MNIST into ``self.root`` if needed. Consumes draws from the
        global ``np.random`` state. Writes ``train.pt``, ``detect.pt`` and
        ``test.pt`` under ``<root>/ColoredMNIST``.
    """
    colored_mnist_dir = os.path.join(self.root, 'ColoredMNIST')
    split_files = [os.path.join(colored_mnist_dir, f'{env}.pt') for env in self._ENVS]
    if all(os.path.exists(path) for path in split_files):
      print('Colored MNIST dataset already exists')
      return

    print('Preparing Colored MNIST')
    train_mnist = datasets.mnist.MNIST(self.root, train=True, download=True)

    train_set = []
    detect_set = []
    test_set = []

    # Images [0, train_dataset_end) form the pool for the train and detect
    # splits. Every image at or beyond it goes to the test split.
    train_dataset_end = 40000
    training_indices = np.linspace(0, train_dataset_end - 1, train_len, dtype=int)

    # Detection indices are evenly spaced over the pool minus the training indices.
    remaining = np.setdiff1d(np.arange(train_dataset_end), training_indices)
    detection_indices = remaining[np.linspace(0, len(remaining) - 1, detection_len, dtype=int)]

    # Sets give O(1) membership tests in the per-image loop below.
    # By contrast, ``x in ndarray`` scans the whole array on every test.
    training_idx_set = set(training_indices.tolist())
    detection_idx_set = set(detection_indices.tolist())

    for idx, (im, label) in enumerate(train_mnist):
      if idx % 10000 == 0:
        print(f'Converting image {idx}/{len(train_mnist)}')

      in_train = idx in training_idx_set
      in_detect = idx in detection_idx_set
      # Pool images that were picked for neither train nor detect are unused.
      if idx < train_dataset_end and not (in_train or in_detect):
        continue

      # Binary label from the digit, flipped with 25% probability.
      binary_label = 0 if label < 5 else 1
      if np.random.uniform() < 0.25:
        binary_label ^= 1

      # Color: red for label 0, green for label 1. The color is then flipped
      # with a probability that depends on the index: 10% for the first half
      # of the pool, 40% for the second half, 90% in the test environment.
      color_red = binary_label == 0
      if idx < train_dataset_end // 2:
        color_flip_prob = 0.1
      elif idx < train_dataset_end:
        color_flip_prob = 0.4
      else:
        color_flip_prob = 0.9
      if np.random.uniform() < color_flip_prob:
        color_red = not color_red

      colored_arr = color_grayscale_arr(np.array(im), red=color_red)
      sample = (Image.fromarray(colored_arr), binary_label)
      if in_detect:
        detect_set.append(sample)
      elif in_train:
        train_set.append(sample)
      else:
        test_set.append(sample)

    os.makedirs(colored_mnist_dir, exist_ok=True)
    for path, split in zip(split_files, (train_set, detect_set, test_set)):
      torch.save(split, path)
####################################################################################