import os
import cv2
import numpy as np
import sys

def process_wflw(anno, target_size, image_dir=None, scale=1.6):
    """Crop, resize and normalise one face described by a WFLW annotation line.

    Args:
        anno: Whitespace-split tokens of one annotation line. Tokens 0-195 are
            98 landmarks as interleaved ``x y`` pixel coordinates, tokens
            196-199 are the face box ``xmin ymin xmax ymax`` and the last
            token is the image path relative to ``image_dir``.
        target_size: Side length in pixels of the square output crop.
        image_dir: Directory holding the source images. Defaults to
            ``../data/WFLW/WFLW_images`` (relative to the working directory).
        scale: Factor by which the face box is enlarged around its centre
            before cropping.

    Returns:
        ``(image_crop, lms)``: the ``target_size`` x ``target_size`` crop and
        a list of 98 ``(x, y)`` tuples expressed as fractions of the clamped
        crop box's width and height.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
        ValueError: if the clamped crop box is empty.
    """
    if image_dir is None:
        image_dir = os.path.join('..', 'data', 'WFLW', 'WFLW_images')
    image_path = os.path.join(image_dir, anno[-1])
    image = cv2.imread(image_path)
    # cv2.imread reports failure by returning None instead of raising.
    if image is None:
        raise FileNotFoundError('could not read image: ' + image_path)
    image_height, image_width = image.shape[:2]

    def _clamp(value, upper):
        # Clamp to [0, upper] exactly like "v if v >= 0 else 0" then
        # "v if v <= upper else upper".
        value = value if value >= 0 else 0
        return value if value <= upper else upper

    coords = [float(v) for v in anno[:196]]
    # Landmarks are clamped to the full width/height, while the crop box
    # below is clamped to width-1 / height-1.
    xs = [_clamp(x, image_width) for x in coords[0::2]]
    ys = [_clamp(y, image_height) for y in coords[1::2]]

    xmin, ymin, xmax, ymax = [float(v) for v in anno[196:200]]
    # Grow the box by (scale - 1) of its size, split evenly on both sides.
    pad_x = (xmax - xmin) * (scale - 1) / 2
    pad_y = (ymax - ymin) * (scale - 1) / 2
    xmin = max(xmin - pad_x, 0)
    ymin = max(ymin - pad_y, 0)
    xmax = min(xmax + pad_x, image_width - 1)
    ymax = min(ymax + pad_y, image_height - 1)
    width = xmax - xmin
    height = ymax - ymin

    image_crop = image[int(ymin):int(ymax), int(xmin):int(xmax), :]
    if image_crop.size == 0:
        raise ValueError('empty face crop for image: ' + image_path)
    image_crop = cv2.resize(image_crop, (target_size, target_size))

    # The crop uses the int-truncated box, whereas normalisation uses the
    # float box. This is kept as-is so regenerated labels match earlier runs.
    pts = np.array(list(zip(xs, ys)), dtype=float)
    pts = (pts - [xmin, ymin]) / [width, height]
    return image_crop, [tuple(p) for p in pts.tolist()]


def process_FLSC(anno, target_size, image_dir=None, scale=1.6):
    """Crop, resize and normalise one face described by an FLSC annotation line.

    Args:
        anno: Whitespace-split tokens of one annotation line. Tokens 0-195 are
            98 landmarks as interleaved ``x y`` pixel coordinates, tokens
            196-199 are the face box ``xmin ymin xmax ymax`` and the last
            token is the image path relative to ``image_dir``.
        target_size: Side length in pixels of the square output crop.
        image_dir: Directory holding the source images. Defaults to
            ``../data/FLSC/FLSC_images`` (relative to the working directory).
        scale: Factor by which the face box is enlarged around its centre
            before cropping.

    Returns:
        ``(image_crop, lms)``: the ``target_size`` x ``target_size`` crop and
        a list of 98 ``(x, y)`` tuples expressed as fractions of the clamped
        crop box's width and height.

    Raises:
        FileNotFoundError: if OpenCV cannot read the image.
        ValueError: if the clamped crop box is empty.
    """
    if image_dir is None:
        image_dir = os.path.join('..', 'data', 'FLSC', 'FLSC_images')
    image_path = os.path.join(image_dir, anno[-1])
    image = cv2.imread(image_path)
    # cv2.imread reports failure by returning None instead of raising.
    if image is None:
        raise FileNotFoundError('could not read image: ' + image_path)
    image_height, image_width = image.shape[:2]

    def _clamp(value, upper):
        # Clamp to [0, upper] exactly like "v if v >= 0 else 0" then
        # "v if v <= upper else upper".
        value = value if value >= 0 else 0
        return value if value <= upper else upper

    coords = [float(v) for v in anno[:196]]
    # Landmarks are clamped to the full width/height, while the crop box
    # below is clamped to width-1 / height-1.
    xs = [_clamp(x, image_width) for x in coords[0::2]]
    ys = [_clamp(y, image_height) for y in coords[1::2]]

    xmin, ymin, xmax, ymax = [float(v) for v in anno[196:200]]
    # Grow the box by (scale - 1) of its size, split evenly on both sides.
    pad_x = (xmax - xmin) * (scale - 1) / 2
    pad_y = (ymax - ymin) * (scale - 1) / 2
    xmin = max(xmin - pad_x, 0)
    ymin = max(ymin - pad_y, 0)
    xmax = min(xmax + pad_x, image_width - 1)
    ymax = min(ymax + pad_y, image_height - 1)
    width = xmax - xmin
    height = ymax - ymin

    image_crop = image[int(ymin):int(ymax), int(xmin):int(xmax), :]
    if image_crop.size == 0:
        raise ValueError('empty face crop for image: ' + image_path)
    image_crop = cv2.resize(image_crop, (target_size, target_size))

    # The crop uses the int-truncated box, whereas normalisation uses the
    # float box. This is kept as-is so regenerated labels match earlier runs.
    pts = np.array(list(zip(xs, ys)), dtype=float)
    pts = (pts - [xmin, ymin]) / [width, height]
    return image_crop, [tuple(p) for p in pts.tolist()]


def gen_meanface(root_folder, data_name):
    """Average all label rows of ``<root>/<data>/train.txt`` into ``meanface.txt``.

    Each non-blank line of ``train.txt`` starts with an image name, followed
    by an optional token containing ``'('`` (skipped) and then numeric
    values. The column-wise mean of those values is written, space-separated,
    to ``<root>/<data>/meanface.txt``. If the mean has an odd length, its last
    value is dropped so that only ``x y`` pairs remain.

    Raises:
        ValueError: if ``train.txt`` contains no non-blank lines.
    """
    train_path = os.path.join(root_folder, data_name, 'train.txt')
    rows = []
    with open(train_path, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                continue  # tolerate blank lines, e.g. a trailing newline
            values = fields[1:]  # drop the image name
            if '(' in values[0]:
                values = values[1:]
            rows.append([float(v) for v in values])

    if not rows:
        raise ValueError('no annotations found in ' + train_path)

    meanface = np.mean(np.array(rows), axis=0).tolist()
    # An odd count means there is a trailing non-coordinate value, such as
    # the extra label token that gen_data appends to FLSC lines.
    if len(meanface) % 2 != 0:
        meanface = meanface[:-1]

    with open(os.path.join(root_folder, data_name, 'meanface.txt'), 'w') as f:
        f.write(' '.join(str(v) for v in meanface))


def _gen_data_split(root_folder, data_name, split, process_fn, name_prefix,
                    append_token, target_size):
    """Crop every face of one annotation split and write ``<split>.txt``.

    Reads ``<root>/<data>/<data>_annotations/list_98pt_rect_attr_train_test/
    list_98pt_rect_attr_<split>.txt`` and, for each non-blank line, saves the
    crop as ``<name_prefix>_<split>_NNNN.jpg`` under
    ``<root>/<data>/images_<split>/``. It then appends
    ``<crop name> x0 y0 x1 y1 ... `` (plus token 200 of that same annotation
    line when ``append_token`` is true) to ``<root>/<data>/<split>.txt``.

    Raises:
        OSError: if OpenCV fails to write a crop.
    """
    data_dir = os.path.join(root_folder, data_name)
    anno_path = os.path.join(data_dir, data_name + '_annotations',
                             'list_98pt_rect_attr_train_test',
                             'list_98pt_rect_attr_' + split + '.txt')
    with open(anno_path, 'r') as f:
        annos = [line.split() for line in f]
    annos = [tokens for tokens in annos if tokens]  # skip blank lines

    image_dir = os.path.join(data_dir, 'images_' + split)
    with open(os.path.join(data_dir, split + '.txt'), 'w') as out:
        for count, anno in enumerate(annos, start=1):
            image_crop, lms = process_fn(anno, target_size)
            image_crop_name = '{}_{}_{:04d}.jpg'.format(name_prefix, split, count)
            print(image_crop_name)
            crop_path = os.path.join(image_dir, image_crop_name)
            # cv2.imwrite returns False instead of raising on failure.
            if not cv2.imwrite(crop_path, image_crop):
                raise OSError('could not write image: ' + crop_path)
            parts = [image_crop_name + ' ']
            parts.extend(str(x) + ' ' + str(y) + ' ' for x, y in lms)
            if append_token:
                # Use this line's own token. The original test loop mistakenly
                # wrote the last *training* line's token here.
                parts.append(anno[200])
            out.write(''.join(parts) + '\n')


def gen_data(root_folder, data_name, target_size):
    """Generate cropped images, label files and the mean face for a dataset.

    For ``data_name`` 'WFLW' or 'FLSC', the train and test splits are
    processed with ``process_wflw`` / ``process_FLSC``. Crops go to
    ``<root>/<data>/images_{train,test}/`` and labels to
    ``<root>/<data>/{train,test}.txt``, and then ``gen_meanface`` is run.
    FLSC label lines additionally end with token 200 of their source
    annotation line. Any other ``data_name`` prints ``'Wrong data!'`` and
    does nothing else.

    Args:
        root_folder: Directory containing the dataset folders.
        data_name: 'WFLW' or 'FLSC'.
        target_size: Side length in pixels of the square crops.
    """
    if data_name == 'WFLW':
        process_fn, name_prefix, append_token = process_wflw, 'wflw', False
    elif data_name == 'FLSC':
        process_fn, name_prefix, append_token = process_FLSC, 'flsc', True
    else:
        # Checked before touching the file system, so an unknown name no
        # longer creates (or fails to create) stray directories.
        print('Wrong data!')
        return

    for split in ('train', 'test'):
        os.makedirs(os.path.join(root_folder, data_name, 'images_' + split),
                    exist_ok=True)

    for split in ('train', 'test'):
        _gen_data_split(root_folder, data_name, split, process_fn,
                        name_prefix, append_token, target_size)

    gen_meanface(root_folder, data_name)

if __name__ == '__main__':
    # Usage: python <this script> {WFLW|FLSC}
    # Data is read from and written under ../data, relative to the working directory.
    if len(sys.argv) < 2:
        print('please input the data name.')
        print('1. WFLW')
        print('2. FLSC')
        # A missing required argument is an error, so exit non-zero.
        # sys.exit is used because the builtin exit() is injected by the
        # site module and is not guaranteed to exist.
        sys.exit(1)
    data_name = sys.argv[1]
    gen_data('../data', data_name, 256)


