# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Main script to launch PixMix training on MNIST/CIFAR-10/100.

Supports WideResNet, ResNeXt models on CIFAR-10 and CIFAR-100 as well
as evaluation on CIFAR-10-C and CIFAR-100-C.

Example usage:
  `python pixmix.py`
"""
from __future__ import print_function

import argparse
import os
import shutil
import time

import pixmix_utils as utils
import numpy as np
import matplotlib.pyplot as plt

import torch
# import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.nn.functional as F
from torchvision import datasets
from torchvision import transforms
from torchvision.utils import save_image

# Command-line interface. Parsing happens at import time, so `args` is a
# module-level global read by `pixmix`, `augment_input` and `main`.
parser = argparse.ArgumentParser(
    description='Generates PixMix-augmented images and saves them to disk',
    formatter_class=argparse.ArgumentDefaultsHelpFormatter)

# Dataset / path options
parser.add_argument(
    '--dataset',
    type=str,
    default='mnist',
    choices=['mnist', 'cifar10', 'cifar100', 'AwA2'],
    help='Choose between MNIST, CIFAR-10, CIFAR-100 and AwA2.')
parser.add_argument(
    '--img-dim',
    type=int,
    default=32,
    help='dimension of images')
parser.add_argument(
    '--data-path',
    type=str,
    default='/nobackup/jihye/data',
    help='Root data directory (MNIST/CIFAR are downloaded into sub-folders; '
         'for AwA2 it is the image-folder root)')
parser.add_argument(
    '--mixing-set',
    type=str,
    default='/nobackup/jihye/data/fractals_and_fvis/fractals',
    help='Mixing set directory.')
parser.add_argument(
    '--save-path',
    type=str,
    default='/nobackup/jihye/data/mnist-fractals',
    help='Path to save pixmix results'
)
parser.add_argument(
    '--use_300k',
    action='store_true',
    help='use 300K random images as aug data'
)

# PixMix options
parser.add_argument(
    '--beta',
    default=3,
    type=int,
    help='Severity of mixing')
parser.add_argument(
    '--k',
    default=4,
    type=int,
    help='Mixing iterations')
parser.add_argument(
    '--aug-severity',
    default=3,
    type=int,
    help='Severity of base augmentation operators')
parser.add_argument(
    '--all-ops',
    '-all',
    action='store_true',
    help='Turn on all augmentation operations (+brightness,contrast,color,sharpness).')


args = parser.parse_args()
print(args)


# 100 for CIFAR-100, 10 for every other choice.
# NOTE(review): 'AwA2' also resolves to 10 here - confirm that is intended.
# The constant is not referenced anywhere else in this file.
NUM_CLASSES = 100 if args.dataset == 'cifar100' else 10


def pixmix(orig, mixing_pic, preprocess):
  """Apply the PixMix augmentation to a single image.

  The starting image is, with probability 0.5, a randomly augmented copy of
  `orig` (otherwise `orig` itself). It is then combined, for a random number
  of rounds in `[0, args.k]`, with either another augmented copy of `orig` or
  with `mixing_pic`, using a mixing operation drawn from `utils.mixings`.
  The result is clipped to `[0, 1]` after each round.

  Args:
    orig: Source image; must support `.copy()` and be accepted by the
      `tensorize` callable (presumably a PIL image - confirm with caller).
    mixing_pic: Picture from the mixing set, accepted by `tensorize`.
    preprocess: Dict with the callables `'tensorize'` and `'normalize'`.

  Returns:
    The output of `normalize` applied to the mixed tensor.
  """
  mix_ops = utils.mixings
  to_tensor = preprocess['tensorize']
  normalize = preprocess['normalize']

  # Pick the starting point: augmented or untouched original.
  start = augment_input(orig) if np.random.random() < 0.5 else orig
  mixed = to_tensor(start)

  rounds = np.random.randint(args.k + 1)
  for _ in range(rounds):
    # Partner image: a fresh augmentation of the original or the mixing picture.
    partner_src = augment_input(orig) if np.random.random() < 0.5 else mixing_pic
    partner = to_tensor(partner_src)

    combine = np.random.choice(mix_ops)
    mixed = torch.clip(combine(mixed, partner, args.beta), 0, 1)

  return normalize(mixed)

def augment_input(image):
  """Apply one randomly chosen augmentation to a copy of `image`.

  The operation is drawn from `utils.augmentations_all` when `--all-ops` is
  set, otherwise from `utils.augmentations`, and is called with
  `args.aug_severity`.

  Args:
    image: Input image; must provide `.copy()`.

  Returns:
    Whatever the chosen augmentation returns for the copied image.
  """
  if args.all_ops:
    candidates = utils.augmentations_all
  else:
    candidates = utils.augmentations
  chosen = np.random.choice(candidates)
  return chosen(image.copy(), args.aug_severity)

class RandomImages300K(torch.utils.data.Dataset):
    """Dataset over an array loaded from a `.npy` file.

    Every item is returned as `(transform(entry), 0)`; the label is always 0.
    """

    def __init__(self, file, transform):
        """Load the array at `file` and remember `transform`."""
        self.transform = transform
        self.dataset = np.load(file)

    def __len__(self):
        """Return the number of entries along the first axis."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the transformed entry at `index` paired with label 0."""
        transformed = self.transform(self.dataset[index])
        return transformed, 0

class PixMixDataset(torch.utils.data.Dataset):
  """Wraps a dataset so that every item is returned PixMix-augmented.

  Each lookup pairs the requested sample with one randomly drawn picture from
  the mixing set and passes both through `pixmix`.
  """

  def __init__(self, dataset, mixing_set, preprocess):
    """Store the base dataset, the mixing set and the preprocess callables."""
    self.preprocess = preprocess
    self.mixing_set = mixing_set
    self.dataset = dataset

  def __len__(self):
    """Return the length of the wrapped base dataset."""
    return len(self.dataset)

  def __getitem__(self, i):
    """Return `(pixmix(sample, random mixing picture), label)` for index `i`."""
    image, label = self.dataset[i]
    # Draw the mixing picture uniformly; its own label is ignored.
    pick = np.random.choice(len(self.mixing_set))
    partner, _ = self.mixing_set[pick]
    mixed = pixmix(image, partner, self.preprocess)
    return mixed, label


def main():
  """Generate PixMix-augmented images for the selected dataset and save them.

  Steps:
    1. Seed torch and NumPy with fixed seeds.
    2. Build the training set and the mixing set with their transforms.
    3. Save the first 10 transformed training images as `<i>_orig.png`.
    4. Iterate the PixMix-wrapped training set in order (no shuffling) and
       write every mixed image to `args.save_path` as `<i>.png`.

  Raises:
    Exception: if `args.save_path` exists but is not a directory.
  """
  torch.manual_seed(1)
  np.random.seed(1)

  # Load datasets
  resize_dim = 250 if args.dataset == 'AwA2' else 36
  train_transform = transforms.Compose(
      [transforms.Resize(args.img_dim),
       transforms.RandomHorizontalFlip(),
       transforms.RandomCrop(args.img_dim, padding=4)])

  # Mixing pictures are resized and randomly cropped to the training image
  # size; for MNIST they are converted to grayscale first.
  mixing_ops = [transforms.Resize(resize_dim),
                transforms.RandomCrop(args.img_dim)]
  if args.dataset == 'mnist':
    mixing_ops.insert(0, transforms.Grayscale())
  mixing_set_transform = transforms.Compose(mixing_ops)

  to_tensor = transforms.ToTensor()
  if args.dataset == 'mnist':
    normalize = transforms.Normalize(0.5, 0.5)
  else:
    normalize = transforms.Normalize([0.5] * 3, [0.5] * 3)

  if args.dataset == 'mnist':
    train_data = datasets.MNIST(
        os.path.join(args.data_path, 'mnist'), train=True, transform=train_transform, download=True)
  elif args.dataset == 'cifar10':
    train_data = datasets.CIFAR10(
        os.path.join(args.data_path, 'cifar'), train=True, transform=train_transform, download=True)
  elif args.dataset == 'cifar100':
    train_data = datasets.CIFAR100(
        os.path.join(args.data_path, 'cifar'), train=True, transform=train_transform, download=True)
  else:
    train_data = datasets.ImageFolder(args.data_path, transform=train_transform)

  if args.use_300k:
    # The array file is looked up relative to the current working directory.
    mixing_set = RandomImages300K(file='300K_random_images.npy', transform=transforms.Compose(
      [transforms.ToTensor(), transforms.ToPILImage(), transforms.RandomCrop(args.img_dim, padding=4),
      transforms.RandomHorizontalFlip()]))
  else:
    mixing_set = datasets.ImageFolder(args.mixing_set, transform=mixing_set_transform)
  print('train_size', len(train_data))
  print('aug_size', len(mixing_set))

  # Create the output directory. `exist_ok=True` tolerates another process
  # creating it concurrently; a non-directory at that path is still an error.
  if not os.path.isdir(args.save_path):
    if os.path.exists(args.save_path):
      raise Exception('%s is not a dir' % args.save_path)
    os.makedirs(args.save_path, exist_ok=True)

  # check some original images
  for i in range(10):
    img, _ = train_data[i]
    # The mixing-set sample is fetched but not saved. The access is kept on
    # purpose: it exercises the mixing-set pipeline and its random transforms
    # consume RNG state, so dropping it would change the generated images.
    _ = mixing_set[i]
    img.save(os.path.join(args.save_path, f'{i}_orig.png'))

  # Fix dataloader worker issue
  # https://github.com/pytorch/pytorch/issues/5059
  def wif(worker_id):
    """Re-seed NumPy in each DataLoader worker from torch's seed."""
    del worker_id  # Unused; the seed comes from torch.initial_seed().
    uint64_seed = torch.initial_seed()
    ss = np.random.SeedSequence([uint64_seed])
    # More than 128 bits (4 32-bit words) would be overkill.
    np.random.seed(ss.generate_state(4))

  train_data = PixMixDataset(train_data, mixing_set, {'normalize': normalize, 'tensorize': to_tensor})

  train_loader = torch.utils.data.DataLoader(
      train_data,
      batch_size=1,
      shuffle=False,
      num_workers=10,
      pin_memory=True,
      worker_init_fn=wif)

  # Run PixMix
  # Labels are collected per batch and joined once at the end, instead of
  # re-copying the growing array on every iteration.
  label_chunks = []
  for i, (image, target) in enumerate(train_loader):
    if i % 1000 == 0:
      print(f'saving {i}th image....')
    save_image(image, os.path.join(args.save_path, f'{i}.png'))
    label_chunks.append(target.numpy())
  # Leading empty float array keeps the result identical to the previous
  # incremental concatenation (including its dtype).
  labels = np.concatenate([np.array([]), *label_chunks])

  # NOTE(review): persisting the labels is currently disabled; uncomment to
  # write them next to the images.
  # np.save(os.path.join(args.save_path, 'labels.npy'), labels)




if __name__ == '__main__':
  main()
