
import os
import numpy as np

import torch
from torchvision import datasets
import torchvision.datasets.utils as dataset_utils

####################################################################################

def sample_example(a_true, p_e):
    """
    Draw a single (features, label) pair for the synthetic two-feature task.

    Random numbers are consumed from NumPy's global generator in a fixed
    order: one normal draw for x1, then three uniform draws (label noise,
    x2 magnitude, x2 sign flip).

    Args:
        a_true: Threshold on x1; the clean label is 1 when x1 > a_true, else 0.
        p_e: Probability of negating x2 after it has been derived from the label.

    Returns:
        tuple: (x, y) where x is a 1-D torch tensor holding [x1, x2] and
        y is an int in {0, 1}.
    """
    # First feature: Gaussian with mean 0 and std 2
    x1 = np.random.normal(0.0, 2.0)

    # Label noise: with probability 0.25 the thresholded label is inverted
    noisy_label = np.random.uniform() < 0.25
    y = int(x1 > a_true)
    if noisy_label:
        y = 1 - y

    # Second feature: carries the sign of the (possibly noisy) label and a
    # magnitude of 1 plus a uniform jitter
    jitter = np.random.uniform()
    sign = 1 if y == 1 else -1
    x2 = sign + sign*jitter

    # Environment-dependent negation of x2; this is the source of distribution shift
    negate_x2 = np.random.uniform() < p_e
    if negate_x2:
        x2 = -x2

    # Pack both features into a torch tensor
    features = torch.tensor([x1, x2])
    return features, y

class SimpleExample(datasets.VisionDataset):
    """
    Simple example for testing OOD generalization. Prepared using a similar
    procedure to https://arxiv.org/pdf/1907.02893.pdf
      - Two features: x = (x1, x2), with x1 drawn from a normal (mean 0, std 2)
      - Assign binary label y in {0, 1} based on true classifier: y = 1 if x1 > a else 0
      - Flip label with probability 0.25
      - Assign x2 based on label: x2 = s + s*u, with s = 2*y - 1 and u uniform on [0, 1)
      - Negate x2 with a probability that depends on the environment:
          * train / detect: increases linearly with the example index, from 0.0 to 0.3
          * test: constant 0.9

    The three splits are generated together and stored as
    ``<root>/SimpleExample/{train,detect,test}.pt``. When all three files are
    already present they are reused as-is, so the length arguments only take
    effect when the files are (re)generated.

    Args:
      root (string): Root directory of dataset where ``SimpleExample/*.pt`` will exist.
      env (string): Which environment to load. Must be 'train', 'detect', or 'test'.
      train_len (int): Number of training examples to generate.
      detect_len (int): Number of detection examples to generate.
      test_len (int): Number of test examples to generate.

    Raises:
      RuntimeError: If ``env`` is not one of the valid environments.
    """

    # Names of the splits; each one is stored as '<name>.pt'
    _ENVS = ('train', 'detect', 'test')

    def __init__(self, root='./data', env='train', train_len=2000, detect_len=1000, test_len=1000):
        super(SimpleExample, self).__init__(root)

        # Validate first so a bad env name does not trigger data generation
        if env not in self._ENVS:
            raise RuntimeError(f'{env} env unknown. Valid envs are train, detect, and test.')

        self.prepare_simple_example(train_len, detect_len, test_len)
        self.env = env
        self.data_label_tuples = torch.load(os.path.join(self.root, 'SimpleExample', env) + '.pt')

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: (x, target), where x is the feature tensor and target is
            the class label. For the 'detect' environment the tuple is
            (x, target, index).
        """
        x, target = self.data_label_tuples[index]

        # The detection set additionally reports which example was returned
        if self.env == 'detect':
            return x, target, index
        return x, target

    def __len__(self):
        """Return the number of examples in the loaded environment."""
        return len(self.data_label_tuples)

    @staticmethod
    def _sample_drifting_set(length, a_true, p_start, p_end):
        """Sample `length` examples whose x2 flip probability ramps from p_start to p_end."""
        last_idx = length - 1
        examples = []
        for idx in range(length):
            # A single-example set has no ramp to walk along; keep it at
            # p_start rather than dividing by zero
            progress = idx/last_idx if last_idx > 0 else 0.0
            p_t = (p_end - p_start)*progress + p_start
            examples.append(sample_example(a_true, p_t))
        return examples

    def prepare_simple_example(self, train_len, detect_len, test_len):
        """
        Generate the train, detect and test splits and save them under
        ``<root>/SimpleExample``. Does nothing if all three files already exist.

        Args:
            train_len (int): Number of training examples to generate.
            detect_len (int): Number of detection examples to generate.
            test_len (int): Number of test examples to generate.
        """
        simple_example_dir = os.path.join(self.root, 'SimpleExample')
        split_paths = {name: os.path.join(simple_example_dir, name + '.pt') for name in self._ENVS}

        # Check exactly the files written below; looking for any other file
        # name would make this test fail forever and force regeneration on
        # every instantiation
        if all(os.path.exists(path) for path in split_paths.values()):
            print('SimpleExample dataset already exists')
            return

        print('Preparing SimpleExample data.')

        # True classification boundary: y = 1 if x1 > a_true else 0
        a_true = 0.0

        p_e1 = 0.0   # Probability of flipping x2 at the first train/detect index
        p_e2 = 0.3   # Probability of flipping x2 at the last train/detect index
        p_test = 0.9 # Probability of flipping x2 for every test example

        # Sampling order (train, detect, test) is kept fixed so that a seeded
        # NumPy generator always yields the same three splits
        train_set = self._sample_drifting_set(train_len, a_true, p_e1, p_e2)
        detect_set = self._sample_drifting_set(detect_len, a_true, p_e1, p_e2)
        test_set = [sample_example(a_true, p_test) for _ in range(test_len)]

        os.makedirs(simple_example_dir, exist_ok=True)
        torch.save(train_set, split_paths['train'])
        torch.save(detect_set, split_paths['detect'])
        torch.save(test_set, split_paths['test'])
####################################################################################