import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import os
import h5py

k = 5      # number of (a, b) coordinate pairs stored per annotation line
h = 218    # not used by the landmark parsing below
w = 178    # not used by the landmark parsing below

# Parse the landmark list: one annotation per line, token 0 is skipped
# (presumably the image file name -- confirm against the data file) and the
# remaining tokens are consumed pairwise as floats.
pose_rows = []
with open('../../data/celeba_wild/list_landmarks_celeba.txt', 'r') as f:
    # The first two lines are read and discarded (presumably the record
    # count and the column header -- confirm against the data file).
    f.readline()
    f.readline()

    for raw_line in f:
        fields = raw_line.split()
        if not fields:
            # A line without any tokens terminates parsing.
            break

        pose = np.zeros((k, 2))
        # Pairs start at token 1, 3, 5, ...; the n-th pair fills row n.
        for row, start in enumerate(range(1, len(fields) - 1, 2)):
            pose[row] = [float(fields[start]), float(fields[start + 1])]
        pose_rows.append(pose)

poses = np.array(pose_rows)
# Swap the two coordinate columns so that column 1 of the file order comes
# first (presumably (x, y) -> (y, x); verify against the annotation format).
poses = np.stack((poses[..., 1], poses[..., 0]), axis=-1)
print("all landmarks read", poses.shape)

# Load every JPEG in the image directory, resize it to 128x128 and store it
# channel-first in `data`; store the matching landmarks, rescaled to [-1, 1]
# relative to the image's own height/width, in `label`.
#
# Rows are addressed by (numeric file stem - 1), so e.g. "000001.jpg" fills
# row 0. Rows whose image file is missing stay all-zero.
image_dir = '../../data/celeba_wild/img_celeba'

# Size the outputs from the parsed landmarks rather than repeating the
# dataset size as a magic number.
data = np.zeros((poses.shape[0], 3, 128, 128), dtype=np.uint8)
label = np.zeros(poses.shape, dtype=np.float32)

for file_name in os.listdir(image_dir):
    if not file_name.endswith('.jpg'):
        continue

    file_index = int(file_name.split('.')[0]) - 1

    # The context manager releases the file handle deterministically, which
    # matters when iterating over a very large directory.
    with Image.open(os.path.join(image_dir, file_name)) as img:
        w, h = img.size  # PIL reports size as (width, height)
        # Force three channels so grayscale/CMYK files do not break the
        # (H, W, C) -> (C, H, W) transpose below.
        resized = img.convert('RGB').resize((128, 128), resample=Image.BILINEAR)

    data[file_index] = np.asarray(resized, dtype=np.uint8).transpose(2, 0, 1)

    # Column 0 is divided by the height and column 1 by the width, mapping
    # [0, size] onto [-1, 1].
    label[file_index, :, 0] = poses[file_index, :, 0] / h * 2 - 1
    label[file_index, :, 1] = poses[file_index, :, 1] / w * 2 - 1

    # Debug visualisation of one sample (uncomment to inspect):
    # plt.imshow(data[file_index].transpose(1, 2, 0).astype(np.uint8))
    # plt.scatter(label[file_index, :, 1]*64+64, label[file_index, :, 0]*64+64)
    # plt.show()

# Split boundaries (row indices) for the train / validation / test subsets
# (presumably the standard CelebA evaluation partition -- confirm against
# the dataset's partition list).
train_end = 162770
valid_end = 182637

train_data = data[:train_end]
valid_data = data[train_end:valid_end]
test_data = data[valid_end:]

# Save the normalised `label` array (coordinates scaled into [-1, 1] relative
# to each image's own size). The raw `poses` are pixel coordinates in the
# original, un-resized images and therefore do not correspond to the stored
# 128x128 images.
train_pose = label[:train_end]
valid_pose = label[train_end:valid_end]
test_pose = label[valid_end:]

# The context manager guarantees the HDF5 file is flushed and closed even if
# one of the writes fails part-way through.
with h5py.File('celeba_wild.h5', "w") as file:
    for split_name, images, landmarks in (
        ("train", train_data, train_pose),
        ("valid", valid_data, valid_pose),
        ("test", test_data, test_pose),
    ):
        file.create_dataset(split_name + "_data", np.shape(images), h5py.h5t.STD_U8BE, data=images)
        file.create_dataset(split_name + "_label", np.shape(landmarks), "float32", data=landmarks)
