import torch
import facer
from PIL import Image

# Keep only the largest face.
def only_one_face(faces):
    """Reduce a detection result to its single largest face, in place.

    The largest face is the row of ``faces['rects']`` with the greatest
    ``(x2 - x1) * (y2 - y1)`` area. Every entry of ``faces`` is then indexed
    with that row position and re-wrapped with ``unsqueeze(0)`` so the
    leading (per-face) dimension is kept with size 1.

    Args:
        faces: Dict of per-face values indexable along dim 0.
            ``faces['rects']`` must hold one ``(x1, y1, x2, y2)`` row per
            detected face.

    Returns:
        None. ``faces`` is modified in place. When no face was detected the
        dict is left untouched instead of failing on an empty index.
    """
    rects = faces['rects']
    if rects.size(0) == 0:
        # Nothing to choose from; indexing an empty tensor would raise.
        return

    widths = rects[:, 2] - rects[:, 0]
    heights = rects[:, 3] - rects[:, 1]
    # argmax always yields a valid row, even when every area is <= 0.
    largest = int((widths * heights).argmax())

    # Iterate over a snapshot of the keys because values are reassigned.
    for key in list(faces):
        faces[key] = faces[key][largest].unsqueeze(0)

# Return the face region (does not include the full hair).
def to_rectangle(faces, img=None):
    """Crop the first face's bounding box out of an image.

    The box ``faces['rects'][0]`` is expanded outwards to whole pixels
    (floor on the top-left corner, ceil on the bottom-right corner) and
    used to slice the image as ``img[y1:y2, x1:x2, :]``.

    Args:
        faces: Dict whose ``'rects'`` entry holds ``(x1, y1, x2, y2)`` rows;
            only the first row is used.
        img: Image indexed as rows, columns, channels. When omitted, the
            module-level ``parse_img`` is used, which keeps existing
            single-argument calls working.

    Returns:
        The cropped region of ``img``.
    """
    if img is None:
        # Fall back to the module-level parsing result for existing callers.
        img = parse_img

    x1, y1, x2, y2 = faces['rects'][0]
    # Clamp at 0: a box may extend past the image border, and a negative
    # slice bound would wrap around from the far edge instead of clipping.
    left = max(int(torch.floor(x1)), 0)
    top = max(int(torch.floor(y1)), 0)
    right = max(int(torch.ceil(x2)), 0)
    bottom = max(int(torch.ceil(y2)), 0)
    return img[top:bottom, left:right, :]

# Run on the GPU when CUDA is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Read the picture, then convert it to a batched tensor: 1 x 3 x h x w.
image = facer.read_hwc('data/twogirls.jpg')
image = facer.hwc2bchw(image).to(device=device)

# Step 1: detect faces.
face_detector = facer.face_detector('retinaface/mobilenet', device=device)
with torch.inference_mode():
    faces = face_detector(image)

# Step 2: drop every detection except the largest one (mutates `faces`).
only_one_face(faces)

# Step 3: parse the remaining face. "farl/celebm/448" is an alternative model.
face_parser = facer.face_parser('farl/lapa/448', device=device)
with torch.inference_mode():
    faces = face_parser(image, faces)

# Step 4: turn the segmentation logits into a grayscale label visualisation.
seg_logits = faces['seg']['logits']
seg_probs = torch.softmax(seg_logits, dim=1)  # nfaces x nclasses x h x w
n_classes = seg_probs.shape[1]
# Scale each pixel's winning class index by 255 / n_classes.
vis_seg_probs = torch.argmax(seg_probs, dim=1).float() / n_classes * 255
# Collapse the face dimension while keeping a leading batch axis of size 1.
vis_img = torch.sum(vis_seg_probs, dim=0, keepdim=True)
parse_img = facer.get_bhw(vis_img)
print(parse_img.shape)

# Step 5: write the visualisation to disk.
pimage = Image.fromarray(parse_img.cpu().numpy())
pimage.save('result.png')
# Optional interactive previews, left disabled:
# facer.show_bhw(vis_img)
# facer.show_bchw(facer.draw_bchw(image, faces))