from PIL import Image
import numpy as np
import sys

dataset_path = sys.argv[1]
train = int(sys.argv[2])
replicates = int(sys.argv[3])
category = int(sys.argv[4])

dataset = []
labels = []
for imid in range(category, category+1):
  if category < 21:
    for view in range(72):
      im_frame = Image.open('dataset/coil-20-proc/obj' + str(imid) + '__' + str(view) + '.png')
      np_frame = np.array(im_frame.getdata())
      np_frame = np_frame.flatten()[None,:]/255.
      if train == 0:
        dataset.append(np_frame)
      else:
        dataset.append(np_frame + np.random.uniform(low=0., high=.1, size=np_frame.shape))
      labels.append(view*2*np.pi/72.) if view <= 35 else labels.append(2*np.pi - view*2*np.pi/72.) 
      for rep in range(replicates):
        dataset.append(np_frame + np.random.uniform(low=0., high=.1, size=np_frame.shape))
        labels.append(view*2*np.pi/72.) if view <= 35 else labels.append(2*np.pi - view*2*np.pi/72.) 
  else:
    for view in range(72):
      im_frame = Image.open('dataset/coil-20-unproc/obj' + str(imid-20) + '__' + str(view) + '.png')
      np_frame = np.array(im_frame.getdata())
      np_frame = np_frame.flatten()[None,:]/255.
      if train == 0:
        dataset.append(np_frame)
      else:
        dataset.append(np_frame + np.random.uniform(low=0., high=.1, size=np_frame.shape))
      labels.append(view*2*np.pi/72.) if view <= 35 else labels.append(2*np.pi - view*2*np.pi/72.) 
      for rep in range(replicates):
        dataset.append(np_frame + np.random.uniform(low=0., high=.1, size=np_frame.shape))
        labels.append(view*2*np.pi/72.) if view <= 35 else labels.append(2*np.pi - view*2*np.pi/72.) 

labels = np.array(labels)
dataset = np.vstack(dataset)
np.save(dataset_path, dataset)
np.save(dataset_path[:-4] + '_labels.npy', labels)
