import os
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.utils.data as data_utils
from torch.utils.data.distributed import DistributedSampler
import numpy as np

def load_static_mnist(data_dir):
  """Load the statically binarized MNIST splits as TensorDatasets.

  Reads ``binarized_mnist_{train,valid,test}.amat`` from
  ``<data_dir>/MNIST_static``. Each row of whitespace-separated values is one
  image; blank lines are skipped. Images are returned as float32 tensors of
  shape ``(1, 28, 28)`` with values mapped from {0, 1} to {-1, 1}. The
  training images are shuffled in place using NumPy's global RNG.

  The .amat files carry no labels, so every split gets placeholder float64
  zero labels of shape ``(N, 1)``.

  Args:
    data_dir: directory that contains the ``MNIST_static`` folder.

  Returns:
    ``(train_dataset, val_dataset, test_dataset)`` as
    ``torch.utils.data.TensorDataset`` instances.
  """
  split_dir = os.path.join(data_dir, 'MNIST_static')

  def read_split(split):
    # Parse one split into (N, 1, 28, 28) float32, rescaled to {-1, 1}.
    # ndmin=2 keeps a single-row file 2-D; loadtxt skips blank lines.
    path = os.path.join(split_dir, 'binarized_mnist_%s.amat' % split)
    x = np.loadtxt(path, dtype=np.float32, ndmin=2).reshape(-1, 1, 28, 28)
    return 2 * x - 1

  def to_dataset(x):
    # Placeholder labels (float64 zeros); the static-MNIST files are unlabeled.
    y = np.zeros((x.shape[0], 1))
    return data_utils.TensorDataset(torch.from_numpy(x), torch.from_numpy(y))

  x_train = read_split('train')
  # Only the training split is shuffled, as before.
  np.random.shuffle(x_train)
  x_val = read_split('valid')
  x_test = read_split('test')

  return to_dataset(x_train), to_dataset(x_val), to_dataset(x_test)

def _rgb_transforms(augment):
  """Build the 3-channel image pipeline.

  With ``augment=True`` this prepends random horizontal/vertical flips and a
  random exact 90-degree rotation (each with p=0.5). It always ends with
  ToTensor and per-channel normalization (mean 0.5, std 0.5).
  """
  steps = []
  if augment:
    steps += [
      transforms.RandomHorizontalFlip(p=0.5),
      transforms.RandomVerticalFlip(p=0.5),
      # Degrees range (90, 90) means "rotate by exactly 90" when applied.
      transforms.RandomApply([transforms.RandomRotation((90, 90))], p=0.5),
    ]
  steps += [
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
  ]
  return transforms.Compose(steps)


def _greyscale_transforms():
  """Build the single-channel pipeline: ToTensor, then normalize (mean 0.5, std 0.5)."""
  # `(0.5,)` is a real 1-tuple; `(0.5)` would just be a float, which some
  # torchvision versions cannot index per-channel.
  return transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
  ])


def _load_datasets(config):
  """Return ``(train_dataset, test_dataset)`` for ``config.dataset``.

  Raises:
    ValueError: if ``config.dataset`` is not one of the supported names.
  """
  name = config.dataset
  if name == 'CIFAR10':
    train_dataset = datasets.CIFAR10(config.data_dir, train=True, download=True, transform=_rgb_transforms(augment=True))
    test_dataset = datasets.CIFAR10(config.data_dir, train=False, download=True, transform=_rgb_transforms(augment=False))

  elif name == 'MNIST':
    greyscale = _greyscale_transforms()
    train_dataset = datasets.MNIST(config.data_dir, train=True, download=True, transform=greyscale)
    test_dataset = datasets.MNIST(config.data_dir, train=False, download=True, transform=greyscale)

  elif name == 'MNIST_bin':
    # The validation split is loaded but not used here.
    train_dataset, _, test_dataset = load_static_mnist(config.data_dir)

  elif name == 'IMAGENET32':
    from .dataset_imagenet import ImageNetDownSample
    plain = _rgb_transforms(augment=False)
    train_dataset = ImageNetDownSample(root=config.data_dir, train=True, transform=plain)
    test_dataset = ImageNetDownSample(root=config.data_dir, train=False, transform=plain)

  elif name == 'TEXT8':
    from .dataset_text8 import Text8
    data = Text8(root=config.data_dir, seq_len=config.seqlen)
    # NOTE: an earlier variant trained on data.train + data.valid
    # (torch.utils.data.ConcatDataset); only data.train is used now.
    train_dataset, test_dataset = data.train, data.test

  elif name == 'ZINC':
    from .dataset_mol import load_zinc
    train_dataset, test_dataset, _ = load_zinc(config)

  elif name == 'MOSES':
    from .dataset_mol import load_moses
    train_dataset, test_dataset, _ = load_moses(config)

  else:
    # Previously this fell through and crashed with UnboundLocalError.
    raise ValueError('Unsupported dataset: %r' % (name,))

  return train_dataset, test_dataset


def get_dataset(config, distributed=False):
  """Build the train and test DataLoaders for ``config.dataset``.

  Attributes read from ``config``: ``dataset``, ``data_dir`` and
  ``batch_size``. Depending on the dataset and mode it also reads
  ``seqlen`` (TEXT8), ``world_size`` and ``local_rank`` (distributed).

  Args:
    config: configuration object exposing the attributes above.
    distributed: if True, shard the training set across processes with a
      ``DistributedSampler``.

  Returns:
    ``(train_loader, test_loader)``.

  Raises:
    ValueError: if ``config.dataset`` is not a supported name.
  """
  train_dataset, test_dataset = _load_datasets(config)
  kwargs = {'num_workers': 1, 'pin_memory': True, 'drop_last': True}

  if distributed:
    # NOTE(review): DistributedSampler's `rank` is the global rank; local_rank
    # only equals it on a single node — confirm for multi-node runs.
    train_sampler = DistributedSampler(train_dataset, num_replicas=config.world_size, rank=config.local_rank, shuffle=True, drop_last=False)
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.batch_size, sampler=train_sampler, **kwargs)
  else:
    # NOTE(review): unlike the distributed path, this loader uses 4 workers and
    # neither pins memory nor drops the last partial batch — confirm intended.
    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True, num_workers=4)

  # NOTE(review): drop_last=True discards the final partial test batch, so up
  # to batch_size - 1 test examples are never evaluated.
  test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=config.batch_size, shuffle=False, **kwargs)

  return train_loader, test_loader