import sys
import os

# Put the directory one level above this script on the import path.
# Presumably this is what lets `from FaRL import facer` below resolve — confirm.
current_dir = os.path.dirname(os.path.abspath(__file__))
parent_dir = os.path.dirname(current_dir)  # current_dir is absolute, so dirname == "..", normalized
sys.path.append(parent_dir)

from tqdm import tqdm
import torch
from PIL import Image
from FaRL import facer
import numpy as np


# Prefer the GPU when CUDA is available; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# The three models are built once, at import time, and reused by every call
# to get_face_parsing.
face_detector = facer.face_detector('retinaface/mobilenet', device=device)
face_parser = facer.face_parser('farl/lapa/448', device=device)  # alternative model: "farl/celebm/448"
face_aligner = facer.face_aligner('farl/wflw/448', device=device)

def get_face_parsing(jpg_path, save_path,
                     fail_log='./data_process/parsing_fail_list_new_color.txt'):
    """Render the face-parsing map of one image with eye landmarks and save it.

    Faces are detected with ``face_detector`` and only the one with the
    largest bounding box is kept. That face is parsed with ``face_parser``.
    Each pixel's predicted class index is scaled into the range [0, 255),
    rendered with ``facer.get_bhw_no_contour``, and overlaid with the eye
    landmarks from ``face_aligner``. The result is written to ``save_path``.

    Args:
        jpg_path: Path of the input image.
        save_path: Path of the output image (PIL infers the format from the
            file extension).
        fail_log: Text file to which ``save_path`` is appended when no face is
            detected. Its parent directory is created if it does not exist.

    Returns:
        None. When no face is detected, nothing is written to ``save_path``.
    """
    def keep_largest_face(faces):
        # Reduce every entry of ``faces`` (in place) to the single detection
        # with the largest bounding-box area. torch.argmax returns the first
        # maximal index, so ties go to the earliest face.
        rects = faces['rects']  # one (x1, y1, x2, y2) row per detected face
        if rects.size(0) == 1:
            return faces
        areas = (rects[:, 2] - rects[:, 0]) * (rects[:, 3] - rects[:, 1])
        best = int(areas.argmax())
        for key, value in faces.items():
            # Keep a leading face dimension of size 1 so downstream models
            # still receive batched inputs.
            faces[key] = value[best].unsqueeze(0)
        return faces

    image = facer.hwc2bchw(facer.read_hwc(jpg_path)).to(device=device)  # image: 1 x 3 x h x w
    with torch.inference_mode():
        faces = face_detector(image)

    if faces['rects'].size(0) == 0:
        # Record the failure and skip this frame.
        print(f'{save_path} has no faces')
        log_dir = os.path.dirname(fail_log)
        if log_dir:
            os.makedirs(log_dir, exist_ok=True)
        with open(fail_log, 'a') as fp:
            fp.write(save_path + '\r\n')
        return

    keep_largest_face(faces)
    with torch.inference_mode():
        parsings = face_parser(image, faces)
        alignments = face_aligner(image, faces)

    seg_probs = parsings['seg']['logits'].softmax(dim=1)  # nfaces x nclasses x h x w
    n_classes = seg_probs.size(1)
    # Turn each pixel's class index into an intensity in [0, 255) so that the
    # classes are distinguishable in the saved image.
    vis_seg_probs = seg_probs.argmax(dim=1).float() / n_classes * 255
    # Only one face remains, so this just yields a 1 x h x w map.
    vis_img = vis_seg_probs.sum(0, keepdim=True)
    img = facer.get_bhw_no_contour(vis_img)

    for pts in alignments['alignment']:
        img = facer.draw_landmarks_only_eyes(img, None, pts.cpu().numpy(), color=(105, 105, 105))

    Image.fromarray(img).save(save_path)


if __name__ == '__main__':
    # Expected layout: data/jpgs/<video_name>/<frame files>. Outputs go to
    # data/parsing_align_no_contour_new_color/<video_name>/<frame stem>.png.
    data_file = 'data'
    jpgs_file = os.path.join(data_file, 'jpgs')
    video_names = sorted(os.listdir(jpgs_file))
    ext = '.png'

    # Videos are visited in reverse sorted order.
    # NOTE(review): presumably so a second run going front-to-back can share
    # the work via the "already processed" check below — confirm.
    for i in reversed(range(len(video_names))):
        video_name = video_names[i]

        save_folder = os.path.join(data_file, 'parsing_align_no_contour_new_color', video_name)
        # An existing output folder means "already processed". A run that is
        # interrupted mid-video leaves a partial folder that will be skipped
        # on rerun; delete it by hand to reprocess that video.
        if os.path.exists(save_folder):
            print(f'{save_folder} has been processed, continue ...')
            continue
        os.makedirs(save_folder, exist_ok=True)

        video_path = os.path.join(jpgs_file, video_name)
        jpgs = sorted(os.listdir(video_path))
        print(f'start parsing {video_name}, save to {save_folder}, {i + 1} / {len(video_names)}')
        for jpg in tqdm(jpgs):
            jpg_path = os.path.join(video_path, jpg)
            # splitext strips only the final extension, so frame names that
            # contain extra dots (e.g. "clip.0001.jpg") do not collide.
            save_path = os.path.join(save_folder, os.path.splitext(jpg)[0] + ext)
            get_face_parsing(jpg_path, save_path)