"""
Adapted from
https://github.com/zllrunning/face-parsing.PyTorch/blob/master/face_dataset.py
Fuse the per-part CelebAMask-HQ annotation masks of each image into one label image.
"""
import os.path as osp
import os
from PIL import Image
import numpy as np
import cv2


# Facial-part names as they appear in the per-part mask file names.
# A part's label in the fused mask is its 1-based position in this list
# (0 is background). Parts later in the list overwrite earlier ones where
# they overlap, e.g. 'hair' is painted on top of 'skin'.
ATTS = ['skin', 'l_brow', 'r_brow', 'l_eye', 'r_eye', 'eye_g', 'l_ear', 'r_ear', 'ear_r',
        'nose', 'mouth', 'u_lip', 'l_lip', 'neck', 'neck_l', 'cloth', 'hair', 'hat']


def fuse_part_masks(part_masks, shape=(512, 512), fg_value=225):
    """Fuse per-part masks into a single label image.

    Args:
        part_masks: iterable of ``(label, array)`` pairs, applied in order.
            Pixels of ``array`` equal to ``fg_value`` are set to ``label``,
            so later pairs overwrite earlier ones where they overlap.
            Labels must fit in ``uint8`` (0-255).
        shape: shape of the output image; every array must match it.
        fg_value: pixel value marking a part's foreground.

    Returns:
        A ``uint8`` array of ``shape``; 0 where no part was painted.
    """
    mask = np.zeros(shape, dtype=np.uint8)
    for label, part in part_masks:
        mask[part == fg_value] = label
    return mask


def main(src_mask='datasets/CelebAMask-HQ/CelebAMask-HQ-mask-anno/',
         dest_mask='datasets/CelebAMask-HQ/CelebAMask-HQ-mask-anno_One/',
         num_folders=15, imgs_per_folder=2000):
    """Write one fused label mask per image from the per-part annotation masks.

    Folder ``src_mask/<i>`` is read for the part masks of images
    ``i * imgs_per_folder`` .. ``(i + 1) * imgs_per_folder - 1``, named
    ``<5-digit zero-padded image index>_<part>.png``. The fused mask of
    image ``j`` is written to ``dest_mask/<j>.png``.

    Prints each image index after writing it, then the number of part-mask
    files found and the number of part-mask files looked for.

    Raises:
        OSError: if a fused mask cannot be written.
    """
    # cv2.imwrite does not create missing directories; it just returns False.
    os.makedirs(dest_mask, exist_ok=True)

    counter = 0  # part-mask files that existed
    total = 0    # part-mask files looked for
    for i in range(num_folders):
        for j in range(i * imgs_per_folder, (i + 1) * imgs_per_folder):
            parts = []
            for attr_index, att in enumerate(ATTS, 1):
                total += 1
                path = osp.join(src_mask, str(i), '{:05d}_{}.png'.format(j, att))
                if not osp.exists(path):
                    continue  # no annotation file for this part
                counter += 1
                # fuse_part_masks looks for 225 (inherited from upstream); it
                # appears to be white's index in Pillow's default web palette.
                # NOTE(review): that only holds if the PNGs convert via that
                # palette (e.g. RGB sources) - confirm for your copy of the data.
                with Image.open(path) as img:
                    parts.append((attr_index, np.array(img.convert('P'))))

            out_path = osp.join(dest_mask, '{}.png'.format(j))
            if not cv2.imwrite(out_path, fuse_part_masks(parts)):
                raise OSError('failed to write {}'.format(out_path))
            print(j)

    print(counter, total)


if __name__ == "__main__":
    main()
