import os
import numpy as np
import random
import pandas as pd

from PIL import Image
from io import BytesIO

"""
1. Five generic categories: four-legged animals, human figures, airplanes, trucks, and cars.
   [0, 1, 2, 3, 4]
2. 10 instances per category, split 5/5 between the training and test sets
   train: [4, 6, 7, 8, 9]
   test : [0, 1, 2, 3, 5]
3. 9 elevations (30 to 70 degrees every 5 degrees)
   [0, 1, 2, 3, 4, 5, 6, 7, 8]
4. 18 azimuths (0 to 340 degrees every 20 degrees), labelled as degrees / 10.
   [0, 2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26, 28, 30, 32, 34]
5. 6 lighting conditions
   [0, 1, 2, 3, 4, 5]

The training set contains 5 instances of each category (instances 4, 6, 7, 8, and 9),
and the test set contains the remaining 5 instances (instances 0, 1, 2, 3, and 5).
"""

# Latent factors stored alongside each stereo pair in the parquet files.
LATENT_COLUMNS = ['category', 'instance', 'elevation', 'azimuth', 'lighting']


def _stack_stereo_views(frame):
    """Flatten the stereo pairs of one split into a single image column.

    The left-camera images come first, followed by the right-camera images.
    The latent rows are repeated in the same order, so row k of the returned
    latents still describes image k of the returned images.

    Returns a 5-tuple: the left images, the right images, the per-pair
    latents, the stacked images and the stacked latents. The last two are
    re-indexed from 0.
    """
    left = frame['image_lt']
    right = frame['image_rt']
    latents = frame[LATENT_COLUMNS]
    images = pd.concat((left, right), ignore_index=True)
    repeated_latents = pd.concat((latents, latents), ignore_index=True)
    return left, right, latents, images, repeated_latents


train = pd.read_parquet('train-00000-of-00001-ba54590c34eb8af1.parquet')
test  = pd.read_parquet('test-00000-of-00001-b4af1727fb5b132e.parquet')

# TRAIN
(train_img_lt, train_img_rt, train_latent_value,
 train_imgs, train_latents_values) = _stack_stereo_views(train)

# TEST
(test_img_lt, test_img_rt, test_latent_value,
 test_imgs, test_latents_values) = _stack_stereo_views(test)

def _decode_image(record):
    """Decode one encoded image record into a float array scaled to [0, 1].

    ``record`` is a mapping whose ``'bytes'`` entry holds the encoded image
    file, as read from the ``image_lt`` / ``image_rt`` parquet columns.
    The image handle is closed once the pixels have been copied out.
    """
    with Image.open(BytesIO(record['bytes'])) as image:
        # np.asarray forces PIL to decode the pixels before the handle closes.
        return np.asarray(image) / 255.


# Build new decoded Series instead of overwriting elements of the Series
# while iterating over it.
train_imgs = train_imgs.map(_decode_image)
test_imgs = test_imgs.map(_decode_image)

print(test_imgs)

# Directory that receives the per-cell training arrays and the IID test split.
OUT_DIR = './smallnorb_split'
# Category names, listed so that each name's position is its integer label.
CATEGORIES = ['animal', 'human', 'airplane', 'truck', 'car']
# Latent labels kept in the split: every other elevation label, every fourth
# azimuth label and the first five lighting conditions.
ELEVATIONS = [0, 2, 4, 6, 8]
AZIMUTHS = [0, 8, 16, 24, 32]
LIGHTINGS = [0, 1, 2, 3, 4]
# Upper bound on the number of test images drawn from each latent cell.
TEST_PER_CELL = 10
# Seed for the per-cell shuffle so that the IID test split is reproducible.
SPLIT_SEED = 0


def _cell_mask(latents, category, elevation, azimuth, lighting):
    """Return a boolean mask selecting the rows of ``latents`` in one cell.

    A cell is one (category, elevation, azimuth, lighting) combination. The
    ``instance`` column is not filtered, so every instance in the split is kept.
    """
    return ((latents['category'] == category)
            & (latents['elevation'] == elevation)
            & (latents['azimuth'] == azimuth)
            & (latents['lighting'] == lighting))


def _stack(images):
    """Stack a Series of equally shaped image arrays into one numeric array.

    Saving the Series directly would write a pickled object array, which
    ``np.load`` refuses to read unless ``allow_pickle=True`` is passed. An
    empty Series yields an empty float array, because ``np.stack`` rejects an
    empty sequence.
    """
    if len(images) == 0:
        return np.empty((0,))
    return np.stack(images.to_list())


os.makedirs(OUT_DIR, exist_ok=True)
# RandomState (rather than Generator) is accepted by every pandas version's
# ``sample(random_state=...)``; sharing one instance varies the shuffle per cell.
rng = np.random.RandomState(SPLIT_SEED)

iid_chunks = []  # one stacked array per cell, concatenated once at the end
label_test = []  # flat list of category names, one per IID test image

for _category, category in enumerate(CATEGORIES):
    for elevation in ELEVATIONS:
        for azimuth in AZIMUTHS:
            for lighting in LIGHTINGS:
                img = train_imgs[_cell_mask(train_latents_values, _category,
                                            elevation, azimuth, lighting)]
                # The test mask must be built from the *test* latents. Reusing
                # the train latents is only correct if both parquet files
                # happen to list their rows in the same latent order.
                _test_img = test_imgs[_cell_mask(test_latents_values, _category,
                                                 elevation, azimuth, lighting)]
                test_img = _test_img.sample(frac=1, random_state=rng)
                test_img = test_img.head(TEST_PER_CELL)

                np.save(f'{OUT_DIR}/{category}_{elevation}_{azimuth}_{lighting}.npy',
                        _stack(img))
                if len(test_img) > 0:
                    iid_chunks.append(_stack(test_img))
                # One label per image actually selected; a cell may hold fewer
                # than TEST_PER_CELL images.
                label_test.extend([category] * len(test_img))

iid_test = np.concatenate(iid_chunks, axis=0) if iid_chunks else np.array([])
label_test = np.asarray(label_test)
print(f"[DEBUG] test img: {len(iid_test)}")
print(f"[DEBUG] test label: {len(label_test)}")
np.save(f'{OUT_DIR}/iid_test.npy', iid_test)
np.save(f'{OUT_DIR}/label_test.npy', label_test)

