# -*- coding: utf-8 -*-
"""
Created on Tue Nov  26 12:28:20 2024

@author: Ignacio
"""
import io
import os
import cv2
import glob
import numpy as np
import torch
import torchvision.transforms as T

from tqdm import tqdm
from dissect import net

from PIL import Image
from skimage import transform
from insightface.app import FaceAnalysis

# Backbone architecture identifier; passed to ``net`` together with ``weights``.
network = 'r100'

# Checkpoint loaded into the backbone. Alternative checkpoint paths:
#   MODELS/ElasticFace-Cos+/295672backbone.pth
#   MODELS/ms1mv3_arcface_r100_fp16/backbone.pth
#   MODELS/ms1mv3_arcface_r50_fp16/backbone.pth
#   MODELS/glint360k_cosface_r100_fp16_0.1/backbone.pth
#   MODELS/glint360k_cosface_r50_fp16_0.1/backbone.pth
#   MODELS/partial_fc_glint360k_r50/16backbone.pth
weights = 'MODELS/partial_fc_glint360k_r100/16_backbone.pth'

# Use the first CUDA device when one is available, otherwise the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Reference (x, y) positions of the five facial landmarks in the aligned crop.
# Detected landmarks are mapped onto these points by a similarity transform
# before the image is warped to 112x112.
# NOTE(review): presumably the usual 96x112 five-point template; the +8 shift
# on x below would then centre it in a 112-wide crop — confirm.
src = np.array(
    [
        [30.2946, 51.6963],
        [65.5318, 51.5014],
        [48.0252, 71.7366],
        [33.5493, 92.3655],
        [62.7299, 92.2041],
    ],
    dtype=np.float32,
)
src[:, 0] += 8.0


class CustomImageDataset(torch.utils.data.Dataset):
    """Dataset of aligned 112x112 face crops read from a directory tree.

    Files are collected with the glob pattern ``<path_dir>/*/*/*``. The name
    of the directory that directly contains a file is used as its identity
    label.

    Each item is either a tuple ``(image, name, identity)`` or ``None``:

    * ``image``: output of ``T.ToTensor()`` applied to the aligned crop.
    * ``name``: path of the source file.
    * ``identity``: name of the file's parent directory.

    ``None`` is returned when the file cannot be decoded as an image or when
    no face is detected in it; the consumer is expected to drop such items
    (the module-level ``collate_fn`` does this).
    """

    def __init__(self, path_dir):
        """Collect the image paths and set up detection and alignment.

        Args:
            path_dir: Root directory; files three levels below it are used.
        """
        # glob returns paths in arbitrary, filesystem-dependent order; sorting
        # makes the sample order (and everything saved from it) reproducible.
        self.files = sorted(glob.glob(path_dir + '/*/*/*'))
        # Prepare face detection
        self.app = FaceAnalysis(allowed_modules=['detection'], providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.app.prepare(ctx_id=0, det_size=(128, 128))
        self.tform = transform.SimilarityTransform()
        self.transf = T.Compose([
            T.ToTensor(),
            # T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    def __len__(self):
        """Return the number of files found under the root directory."""
        return len(self.files)

    def __getitem__(self, idx):
        """Return the preprocessed sample for ``self.files[idx]`` (or ``None``)."""
        return self.preprocess(self.files[idx])

    def preprocess(self, info_image):
        """Load one image, align the detected face and convert it to a tensor.

        Args:
            info_image: Path of the image file.

        Returns:
            ``(image, name, identity)`` as described on the class, or ``None``
            if the file cannot be read as an image or contains no detectable
            face.
        """
        image = cv2.imread(info_image)
        # cv2.imread signals failure by returning None rather than raising
        # (missing file, non-image file matched by the glob, corrupt data).
        # Skip such files the same way as images without a detected face.
        if image is None:
            return None
        # Detect face and landmarks
        salida = self.app.get(image)
        if not salida:
            return None
        # Only the first detection is used.
        landmark = salida[0]['kps']
        # Standardizing, Aligning and Cropping the face: fit the similarity
        # transform that maps the detected landmarks onto the reference
        # template ``src`` and warp the image with its 2x3 affine part.
        self.tform.estimate(landmark, src)
        M = self.tform.params[0:2, :]
        image = cv2.warpAffine(image, M, (112, 112))
        # Preparing the image to be input to the network
        # NOTE(review): the crop is still in OpenCV's BGR channel order and is
        # only rescaled by ToTensor (no mean/std normalisation) — confirm this
        # is what the network returned by ``net`` expects.
        image = self.transf(image)
        identity = os.path.basename(os.path.dirname(info_image))
        return image, info_image, identity

def collate_fn(batch):
    """Collate a batch after discarding the samples that are ``None``.

    The dataset returns ``None`` for images it could not turn into a sample;
    those entries are removed and the remaining samples are combined with
    PyTorch's ``default_collate``.

    Args:
        batch: List of samples (or ``None``) produced by the dataset.

    Returns:
        Whatever ``default_collate`` returns for the surviving samples.
    """
    # NOTE(review): if every sample is None, default_collate receives an empty
    # list — confirm how the installed PyTorch version handles that.
    kept = [sample for sample in batch if sample is not None]
    return torch.utils.data.dataloader.default_collate(kept)

# Prepare model
modelo = net(network, weights)
modelo = modelo.to(device)
# Inference only: switch layers such as BatchNorm/Dropout to evaluation
# behaviour. This is a no-op if ``net`` already returns the model in eval mode.
modelo.eval()

# Every file under RFW/data/<group>/<identity>/ is processed in batches.
path = 'RFW/data'
test_data = CustomImageDataset(path)
test_dataloader = torch.utils.data.DataLoader(test_data, collate_fn=collate_fn, batch_size=500)

# Parallel lists: x_embd[i] is the embedding of file x_data_name[i], whose
# identity label is x_data_id[i].
x_data_name = []
x_data_id = []
x_embd = []
for batch, names, ids in tqdm(test_dataloader):
    batch = batch.to(device)
    with torch.no_grad():
        embeddings = modelo(batch)
    # Move each batch of embeddings to host memory right away so that device
    # memory use does not grow with the size of the dataset.
    x_embd.extend(embeddings.detach().cpu().numpy())
    x_data_name.extend(names)
    x_data_id.extend(ids)

# One row per processed image, in the same order in all three files.
np.save('RFW/embd_partialFC.npy', np.array(x_embd))
np.save('RFW/x_names.npy', np.array(x_data_name))
np.save('RFW/x_ids.npy', np.array(x_data_id))
