import argparse
import os
from PIL import Image
import numpy as np
import h5py
from scipy import io

parser = argparse.ArgumentParser()
parser.add_argument('--root', type=str, default='data/bbcpose')
parser.add_argument('--new_root', type=str, default='data')
args = parser.parse_args()

if __name__ == '__main__':
    bbc_anno = io.loadmat(os.path.join(args.root, 'bbcpose.mat'))['bbcpose'][0]
    train_data_new_path = os.path.join(args.new_root, 'train_images')
    os.makedirs(train_data_new_path, exist_ok=True)
    index = 0
    for video_idx in range(10):
        for array_idx, frame_idx in enumerate(bbc_anno[video_idx][3][0]):
            img = Image.open(os.path.join(args.root, str(video_idx+1), '{}.jpg'.format(int(frame_idx))))
            kpt = bbc_anno[video_idx][4][:, :, array_idx]
            width, height = img.size
            box_x_1 = kpt[0].min()
            box_x_2 = kpt[0].max()
            box_y_1 = kpt[1].min()
            box_y_2 = kpt[1].max()

            box_x_center = (box_x_1 + box_x_2) / 2
            box_y_center = (box_y_1 + box_y_2) / 2

            bbox_x_min = max(0, box_x_center - 150)
            bbox_x_max = min(width, box_x_center + 150)
            bbox_y_min = max(0, box_y_center - 150)
            bbox_y_max = min(height, box_y_center + 150)

            bbox = (bbox_x_min, bbox_y_min, bbox_x_max, bbox_y_max)
            img = img.crop(bbox).resize((256, 256), resample=Image.BILINEAR)

            img.save(os.path.join(train_data_new_path, str(index) + '.jpg'))
            index += 1

    valid_data = []
    valid_pose = []
    for video_idx in range(10, 15):
        for array_idx, frame_idx in enumerate(bbc_anno[video_idx][5][0]):
            img = Image.open(os.path.join(args.root, str(video_idx + 1), '{}.jpg'.format(int(frame_idx))))
            kpt = bbc_anno[video_idx][6][:, :, array_idx]
            width, height = img.size
            box_x_1 = kpt[0].min()
            box_x_2 = kpt[0].max()
            box_y_1 = kpt[1].min()
            box_y_2 = kpt[1].max()

            box_x_center = (box_x_1 + box_x_2) / 2
            box_y_center = (box_y_1 + box_y_2) / 2

            bbox_x_min = max(0, box_x_center - 150)
            bbox_x_max = min(width, box_x_center + 150)
            bbox_y_min = max(0, box_y_center - 150)
            bbox_y_max = min(height, box_y_center + 150)

            bbox = (bbox_x_min, bbox_y_min, bbox_x_max, bbox_y_max)
            img = img.crop(bbox).resize((256, 256), resample=Image.BILINEAR)

            bbox_w = bbox_x_max - bbox_x_min
            bbox_h = bbox_y_max - bbox_y_min
            # center coordinate space
            pose = np.concatenate(((kpt[1] - bbox_y_min) / bbox_h,
                                   (kpt[0] - bbox_x_min) / bbox_w)).reshape(2, 7).transpose(1, 0)

            valid_data.append(np.asarray(img).transpose(2, 0, 1))
            valid_pose.append(pose)

    test_data = []
    test_pose = []
    test_scale = []
    for video_idx in range(15, 20):
        for array_idx, frame_idx in enumerate(bbc_anno[video_idx][5][0]):
            img = Image.open(os.path.join(args.root, str(video_idx + 1), '{}.jpg'.format(int(frame_idx))))
            kpt = bbc_anno[video_idx][6][:, :, array_idx]
            width, height = img.size
            box_x_1 = kpt[0].min()
            box_x_2 = kpt[0].max()
            box_y_1 = kpt[1].min()
            box_y_2 = kpt[1].max()

            box_x_center = (box_x_1 + box_x_2) / 2
            box_y_center = (box_y_1 + box_y_2) / 2

            bbox_x_min = max(0, box_x_center - 150)
            bbox_x_max = min(width, box_x_center + 150)
            bbox_y_min = max(0, box_y_center - 150)
            bbox_y_max = min(height, box_y_center + 150)

            bbox = (bbox_x_min, bbox_y_min, bbox_x_max, bbox_y_max)
            img = img.crop(bbox).resize((256, 256), resample=Image.BILINEAR)

            bbox_w = bbox_x_max - bbox_x_min
            bbox_h = bbox_y_max - bbox_y_min
            # center coordinate space
            pose = np.concatenate(((kpt[1] - bbox_y_min) / bbox_h,
                                   (kpt[0] - bbox_x_min) / bbox_w)).reshape(2, 7).transpose(1, 0)

            test_data.append(np.asarray(img).transpose(2, 0, 1))
            test_pose.append(pose)
            test_scale.append(np.array([bbox_h/256.0, bbox_w/256.0]))

    print("valid length:", len(valid_data))
    print("test length:", len(test_data))

    file = h5py.File(os.path.join(args.new_root, 'bbcposeHQ.h5'), "w")
    file.create_dataset('valid_data', np.shape(np.array(valid_data)), h5py.h5t.STD_U8BE, data=valid_data)
    file.create_dataset('valid_pose', np.shape(np.array(valid_pose)), "float32", data=valid_pose)
    file.create_dataset('test_data', np.shape(np.array(test_data)), h5py.h5t.STD_U8BE, data=test_data)
    file.create_dataset('test_pose', np.shape(np.array(test_pose)), "float32", data=test_pose)
    file.create_dataset('test_scale', np.shape(np.array(test_scale)), "float32", data=test_scale)
    file.close()
