# -*- coding: utf-8 -*-
"""
Created on Tue Nov  26 12:28:20 2024

@author: Ignacio
"""
import io
import os
import cv2
import glob
import numpy as np
import torch

from tqdm import tqdm

from PIL import Image
from insightface.app import FaceAnalysis


class CustomImageDataset(torch.utils.data.Dataset):
    """Dataset yielding L2-normalized insightface identity embeddings.

    Image files are discovered exactly three directory levels below
    ``path_dir`` (``path_dir/<group>/<identity>/<image>``). For RFW the first
    level is presumably the ethnicity split; confirm against the data. Each
    item is the embedding of the largest face detected in the image. When no
    face is detected the item is ``None``, so batching requires a
    ``None``-filtering ``collate_fn``.
    """

    def __init__(self, path_dir, device='cuda', det_size=(128, 128)):
        """Index the image files and prepare the insightface models.

        Args:
            path_dir: Root directory of the dataset.
            device: Torch device the returned embeddings are moved to
                (defaults to ``'cuda'``, as before).
            det_size: Detector input size passed to ``FaceAnalysis.prepare``.
        """
        # Sorted so that the saved arrays have a reproducible row order; raw
        # glob order depends on the filesystem.
        self.files = sorted(glob.glob(os.path.join(path_dir, '*', '*', '*')))
        self.device = device
        # Face detection + recognition; the provider list tries CUDA first
        # and then CPU.
        self.app = FaceAnalysis(name='antelopev2', root='./', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.app.prepare(ctx_id=0, det_size=det_size)

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return ``preprocess`` applied to the ``idx``-th image file."""
        return self.preprocess(self.files[idx])

    @staticmethod
    def _largest_face(faces):
        """Return the detection whose ``bbox`` (x1, y1, x2, y2) has the largest area."""
        def area(face):
            x1, y1, x2, y2 = face['bbox'][:4]
            return (x2 - x1) * (y2 - y1)
        return max(faces, key=area)

    @staticmethod
    def _identity_from_path(path):
        """Return the identity label, i.e. the name of the image's parent directory."""
        return os.path.basename(os.path.dirname(path))

    def preprocess(self, info_image):
        """Embed the largest face in one image.

        Args:
            info_image: Path of the image file.

        Returns:
            ``(id_emb, name, identity)``, where ``id_emb`` is a float16 tensor
            of shape ``(1, D)`` on ``self.device``, L2-normalized along
            dim 1, ``name`` is the image path, and ``identity`` is the
            parent directory name. Returns ``None`` if no face is detected.
        """
        # convert('RGB') guarantees 3 channels (grayscale/RGBA/palette images
        # would otherwise break the channel flip); `with` closes the file.
        with Image.open(info_image) as pil_img:
            rgb = np.asarray(pil_img.convert('RGB'))
        # Flip RGB -> BGR before detection, as the original code did
        # (presumably insightface expects OpenCV channel order). The copy
        # makes the array contiguous.
        bgr = np.ascontiguousarray(rgb[:, :, ::-1])

        faces = self.app.get(bgr)
        if not faces:
            return None
        face = self._largest_face(faces)  # keep only the largest detection

        # Normalize in float32, then store as float16 on the target device.
        id_emb = torch.tensor(face['embedding'], dtype=torch.float32)[None]
        id_emb = id_emb / torch.norm(id_emb, dim=1, keepdim=True)
        id_emb = id_emb.to(device=self.device, dtype=torch.float16)
        return id_emb, info_image, self._identity_from_path(info_image)

def collate_fn(batch):
    """Collate a batch after dropping samples with no detected face.

    The dataset returns ``None`` for images without a detectable face. Those
    entries are removed before delegating to PyTorch's ``default_collate``.

    Returns:
        The collated batch, or ``None`` when every sample was dropped.
        ``default_collate`` would fail on an empty list, so callers should
        skip ``None`` batches.
    """
    kept = [sample for sample in batch if sample is not None]
    if not kept:
        return None
    return torch.utils.data.dataloader.default_collate(kept)

# Extract embeddings for every RFW image and save them together with the image
# paths and identity labels. The three saved arrays are row-aligned because
# they are extended in lockstep from the same batches.
ruta = '../../RFW/data/'
test_data = CustomImageDataset(ruta)
test_dataloader = torch.utils.data.DataLoader(test_data, collate_fn=collate_fn, batch_size=500)

x_data_name = []
x_data_id = []
x_embd = []
for data in tqdm(test_dataloader):
    if data is None:
        # collate_fn yields None when no image in the batch had a face.
        continue
    embds, names, ids = data
    # One device->host copy per batch (rather than one per row at save time);
    # this also avoids holding every embedding in GPU memory.
    x_embd.extend(embds.detach().cpu())
    x_data_name.extend(names)
    x_data_id.extend(ids)

np.save('../RFW_names.npy', np.array(x_data_name))
np.save('../RFW_ids.npy', np.array(x_data_id))
np.save('../RFW_arcface.npy', np.array([tensor.numpy() for tensor in x_embd]))

