import torch
from insightface.app import FaceAnalysis
import numpy as np


class FaceEvaluator():
    """Compute face-identity similarity between images using insightface.

    Wraps an insightface ``FaceAnalysis`` model: faces are detected in each
    image, their embeddings are collected, and two image sets are compared
    by the highest cosine similarity between any pair of embeddings.
    """

    def __init__(self, max_num=1, gpu_id=0):
        """Load and prepare the insightface model.

        Args:
            max_num: forwarded as ``max_num`` to ``FaceAnalysis.get`` for
                every image processed by this evaluator.
            gpu_id: CUDA device index. A negative value selects CPU-only
                execution.
        """
        if gpu_id >= 0:
            providers = [
                ('CUDAExecutionProvider', {'device_id': gpu_id}),
                'CPUExecutionProvider',
            ]
        else:
            # No valid CUDA device index to hand to the CUDA provider.
            providers = ['CPUExecutionProvider']
        self.face_model = FaceAnalysis(providers=providers)
        # Keep ctx_id consistent with the requested device instead of a
        # hard-coded 0, so a negative gpu_id also means CPU for insightface.
        self.face_model.prepare(ctx_id=gpu_id, det_size=(640, 640))
        self.max_num = max_num

    def pil_to_cv2(self, pil_img):
        """Convert a PIL image (or array-like) to a 3-channel array with the
        channel axis reversed, i.e. RGB -> BGR for OpenCV-style consumers.

        Grayscale, single-channel and RGBA inputs are coerced to three
        channels first so the result always has shape ``(H, W, 3)``.
        """
        # PIL images expose ``convert``; use it so palette / alpha / gray
        # modes become plain RGB before hitting numpy.
        if hasattr(pil_img, 'convert'):
            pil_img = pil_img.convert('RGB')
        arr = np.asarray(pil_img)
        if arr.ndim == 2:
            # (H, W) grayscale -> replicate into three channels.
            arr = np.stack([arr] * 3, axis=-1)
        elif arr.ndim == 3 and arr.shape[2] == 1:
            arr = np.repeat(arr, 3, axis=2)
        elif arr.ndim == 3 and arr.shape[2] == 4:
            # Drop the alpha channel; reversing it would yield ABGR.
            arr = arr[:, :, :3]
        # Reversing the last axis gives a negative-stride view; make it
        # contiguous because some OpenCV routines reject such views.
        return np.ascontiguousarray(arr[:, :, ::-1])

    def get_face_emb(self, img):
        """Return the embeddings of the faces detected in ``img``.

        Args:
            img: a numpy array (used as-is) or any other image object,
                which is first passed through ``pil_to_cv2``.

        Returns:
            A float tensor of shape ``(num_faces, embedding_dim)``, or
            ``None`` when no face is detected.
        """
        if not isinstance(img, np.ndarray):
            img = self.pil_to_cv2(img)
        faces = self.face_model.get(img, max_num=self.max_num)

        if len(faces) == 0:
            return None

        embeddings = np.stack([f['embedding'] for f in faces])
        return torch.as_tensor(embeddings, dtype=torch.float32)

    def get_all_face_emb(self, images):
        """Collect face embeddings over an iterable of images.

        Images without a detected face are skipped.

        Returns:
            A tensor of all embeddings concatenated along dim 0, or an
            empty tensor (``len(...) == 0``) when no face is found at all.
        """
        per_image = (self.get_face_emb(im) for im in images)
        found = [emb for emb in per_image if emb is not None]
        if not found:
            return torch.Tensor()
        return torch.cat(found, dim=0)

    @staticmethod
    def _max_cosine_similarity(feat1, feat2):
        """Return the largest pairwise cosine similarity between the rows of
        ``feat1`` and the rows of ``feat2`` as a Python float."""
        # ``normalize`` clamps the norm with a small epsilon, so an all-zero
        # embedding yields similarity 0 instead of NaN from a 0/0 division.
        unit1 = torch.nn.functional.normalize(feat1, dim=-1)
        unit2 = torch.nn.functional.normalize(feat2, dim=-1)
        return (unit1 @ unit2.T).max().item()

    def face_similarity(self, gen_imgs, ref_imgs):
        """Calculate face similarity using insightface.

        Args:
            gen_imgs: iterable of generated images.
            ref_imgs: iterable of reference images.

        Returns:
            The maximum cosine similarity between any face embedding from
            ``gen_imgs`` and any from ``ref_imgs``; ``0.0`` when either set
            contains no detectable face.
        """
        feat1 = self.get_all_face_emb(gen_imgs)
        if len(feat1) == 0:  # no face found
            return 0.0

        feat2 = self.get_all_face_emb(ref_imgs)
        if len(feat2) == 0:  # no face found
            return 0.0

        return self._max_cosine_similarity(feat1, feat2)
