import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import h5py

# k landmark points per image. Landmark x coordinates are divided by w and y
# coordinates by h below, so positions end up relative to the image size.
k = 5
h = 218
w = 178
poses = []

# Parser layout: two preamble lines (skipped), then one row per image of the
# form "<image name> x1 y1 x2 y2 ...". Reading stops at the first line that
# is empty after splitting (or at EOF).
with open('list_landmarks_align_celeba.txt', 'r') as f:
    f.readline()
    f.readline()

    fields = f.readline().split()
    while fields:
        pose = np.zeros((k, 2))
        # Token 0 is the image name; tokens (1, 2), (3, 4), ... are (x, y).
        for slot, (x_tok, y_tok) in enumerate(zip(fields[1::2], fields[2::2])):
            pose[slot] = [float(x_tok) / w, float(y_tok) / h]
        poses.append(pose)
        fields = f.readline().split()

# Stack into a (num_images, k, 2) array, then reorder the last axis from
# (x, y) to (y, x). The .copy() yields a contiguous array.
poses = np.array(poses)
poses = poses[:, :, ::-1].copy()
print("all landmarks read", poses.shape)

def _read_mafl_indices(path):
    """Read a MAFL split file and return its image indices as a NumPy array.

    Every non-blank line is taken to name one image. The text before the
    first '.' (the file stem) is parsed as an integer index. Blank lines are
    skipped rather than crashing ``int()``.

    Args:
        path: Path to the split file (e.g. 'training.txt').

    Returns:
        1-D integer array of indices, in file order.
    """
    with open(path, 'r') as split_file:
        return np.array([
            int(entry.strip().split('.')[0])
            for entry in split_file
            if entry.strip()
        ])


mafl_training_index = _read_mafl_indices('training.txt')
print('MAFL training length', mafl_training_index.shape)

mafl_test_index = _read_mafl_indices('testing.txt')
print('MAFL test length', mafl_test_index.shape)

data_root = 'img_align_celeba_png/img_align_celeba_png'
celeba_wo_mafl = []
celeba_wo_mafl_label = []
mafl_train = []
mafl_train_label = []
mafl_test = []
mafl_test_label = []

# Hash sets give O(1) membership. `x in ndarray` scans the whole array, and
# this test runs once per image.
mafl_training_set = {int(i) for i in mafl_training_index}
mafl_test_set = {int(i) for i in mafl_test_index}

# Sort so the output ordering is reproducible; os.listdir order is arbitrary.
for file_name in sorted(os.listdir(data_root)):
    if not file_name.endswith('.png'):
        continue
    file_index = int(file_name.split('.')[0])
    # poses is filled row by row from the landmark file, so row 0 is its first
    # entry. Image stems are 1-based, so the row is file_index - 1; indexing
    # with file_index shifted every label and overran on the last image.
    # NOTE(review): assumes the landmark file lists images in file-number
    # order starting at 1 — confirm against the dataset.
    label = poses[file_index - 1]
    with Image.open(os.path.join(data_root, file_name)) as img:
        # Force 3 channels so the HWC -> CHW transpose below is always valid.
        resized = img.convert('RGB').resize((128, 128), resample=Image.BILINEAR)
    chw = np.asarray(resized).transpose(2, 0, 1)

    if file_index in mafl_training_set:
        mafl_train.append(chw)
        mafl_train_label.append(label)
    elif file_index in mafl_test_set:
        mafl_test.append(chw)
        mafl_test_label.append(label)
    else:
        celeba_wo_mafl.append(chw)
        celeba_wo_mafl_label.append(label)

# (dataset name, collected items, on-disk dtype) for each output dataset.
datasets = (
    ("mafl_train_data", mafl_train, h5py.h5t.STD_U8BE),
    ("mafl_train_label", mafl_train_label, "float32"),
    ("mafl_test_data", mafl_test, h5py.h5t.STD_U8BE),
    ("mafl_test_label", mafl_test_label, "float32"),
    ("celeba_wo_mafl", celeba_wo_mafl, h5py.h5t.STD_U8BE),
    ("celeba_wo_mafl_label", celeba_wo_mafl_label, "float32"),
)

with h5py.File('celeba.h5', "w") as h5_file:
    for dataset_name, items, dataset_dtype in datasets:
        # Convert once and reuse for both shape and data, instead of building
        # the (large) array twice.
        array = np.asarray(items)
        h5_file.create_dataset(dataset_name, array.shape, dataset_dtype, data=array)
