# -*- coding: utf-8 -*-
"""
Created on Tue Nov  26 12:28:20 2024

@author: Ignacio
"""
import io
import os
import cv2
import glob
import numpy as np
import torch
import torchvision.transforms as T

from tqdm import tqdm
from dissect import net

from PIL import Image
from skimage import transform
from insightface.app import FaceAnalysis

# Backbone architecture shared by every checkpoint listed below.
network = 'r100'

# Checkpoints whose embeddings are extracted, in the order weights1..weights4.
# The commented-out paths are alternative checkpoints kept for quick swapping.
# weights1 ='../MODELS/curricularface/CurricularFace_Backbone.pth'
weights1 = '../MODELS/ElasticFace-Arc/295672backbone.pth'
weights2 = '../MODELS/ms1mv3_arcface_r100_fp16/backbone.pth'
# weights ='../MODELS/ms1mv3_arcface_r50_fp16/backbone.pth'
weights3 = '../MODELS/glint360k_cosface_r100_fp16_0.1/backbone.pth'
# weights ='../MODELS/glint360k_cosface_r50_fp16_0.1/backbone.pth'
weights4 = '../MODELS/partial_fc_glint360k_r100/16_backbone.pth'
# weights ='../MODELS/partial_fc_glint360k_r50/16backbone.pth'

# Run on the first CUDA device when one is available, otherwise on the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Reference (x, y) positions of the five facial keypoints that detected
# landmarks are aligned onto before the 112x112 crop in
# CustomImageDataset.preprocess. Every x coordinate is shifted right by 8 px
# (presumably widening a 96x112 template to 112x112 -- confirm).
src = np.array([
    [30.2946, 51.6963],
    [65.5318, 51.5014],
    [48.0252, 71.7366],
    [33.5493, 92.3655],
    [62.7299, 92.2041],
], dtype=np.float32) + np.array([8.0, 0.0], dtype=np.float32)


class CustomImageDataset(torch.utils.data.Dataset):
    """Dataset that detects, aligns and crops one face per image file.

    Every entry directly inside ``path_dir`` is one item. ``__getitem__``
    returns ``(image, path, file_name)`` when a face was found and aligned,
    or ``None`` otherwise (unreadable file, no detected face, failed
    alignment). Pair it with a collate function that drops ``None`` items.
    """

    def __init__(self, path_dir):
        self.files = self._list_files(path_dir)
        # Face detector only (no recognition model) from insightface.
        self.app = FaceAnalysis(allowed_modules=['detection'], providers=['CUDAExecutionProvider', 'CPUExecutionProvider'])
        self.app.prepare(ctx_id=0, det_size=(128, 128))
        self.tform = transform.SimilarityTransform()
        # ToTensor: HWC uint8 array -> CHW float tensor in [0, 1].
        # No mean/std normalization is applied.
        self.transf = T.Compose([
            T.ToTensor(),
        ])

    @staticmethod
    def _list_files(path_dir):
        """Return the paths of all entries in ``path_dir``, sorted.

        Sorting makes item order (and thus the row order of the saved
        arrays) reproducible; raw ``glob`` order depends on the filesystem.
        """
        return sorted(glob.glob(os.path.join(path_dir, '*')))

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        return self.preprocess(self.files[idx])

    def preprocess(self, info_image):
        """Load, detect, align and tensorize a single image.

        Args:
            info_image: path to an image file.

        Returns:
            ``(image, name, identity)``, where ``image`` is the aligned
            112x112 crop as a CHW float tensor, ``name`` is the input path and
            ``identity`` is its base file name. Returns ``None`` if the file
            cannot be read, no face is detected, or the similarity transform
            cannot be estimated.
        """
        image = cv2.imread(info_image)
        if image is None:
            # cv2.imread reports unreadable or non-image files by returning None.
            return None
        faces = self.app.get(image)
        if not faces:
            return None
        # Only the first detection is used.
        landmark = faces[0]['kps']
        # Fit a similarity transform mapping the detected keypoints onto the
        # reference template `src`, then warp into a 112x112 crop.
        if not self.tform.estimate(landmark, src):
            return None
        M = self.tform.params[0:2, :]
        image = cv2.warpAffine(image, M, (112, 112))
        # NOTE(review): cv2 loads images as BGR, and no channel swap or
        # [-1, 1] scaling happens here. Confirm this matches what the
        # `dissect.net` backbones expect.
        image = self.transf(image)
        # NOTE(review): "identity" is the file name; for CelebA this is not a
        # person ID -- confirm this is what downstream code expects.
        identity = os.path.basename(info_image)
        return image, info_image, identity

def collate_fn(batch):
    """Collate a batch after dropping samples for which preprocessing returned None.

    Args:
        batch: list of dataset items, some of which may be ``None``.

    Returns:
        The ``default_collate`` result for the remaining samples, or ``None``
        when every sample was dropped. ``default_collate`` cannot collate an
        empty list, so the caller must skip ``None`` batches.
    """
    kept = [sample for sample in batch if sample is not None]
    if not kept:
        return None
    return torch.utils.data.dataloader.default_collate(kept)

# Prepare models: one backbone of architecture `network` per checkpoint, in
# the order weights1..weights4, moved to `device` and set to inference mode.
modelos = []
for weights in [weights1, weights2, weights3, weights4]:
    modelo = net(network, weights)
    # eval() makes BatchNorm use its running statistics (and disables any
    # dropout) during embedding extraction. It is harmless if `net` already
    # returns the model in eval mode. Assumes `net` returns a torch nn.Module.
    modelo = modelo.to(device).eval()
    modelos.append(modelo)

path = 'img_align_celeba'
test_data = CustomImageDataset(path)
test_dataloader = torch.utils.data.DataLoader(test_data, collate_fn=collate_fn, batch_size=500)

x_data_name = []
x_data_id = []
# One list of per-batch (batch_size, dim) numpy arrays per model, in the same
# order as `modelos`.
x_embd = [[] for _ in modelos]
for data in tqdm(test_dataloader):
    if data is None:
        # Every image in this batch was dropped (unreadable or no face found).
        continue
    batch, names, ids = data
    batch = batch.to(device)
    with torch.no_grad():
        for chunks, modelo in zip(x_embd, modelos):
            # L2-normalize each embedding row, then move the whole batch to
            # the CPU at once instead of keeping many small device tensors.
            chunks.append(torch.nn.functional.normalize(modelo(batch)).cpu().numpy())
    x_data_name.extend(names)
    x_data_id.extend(ids)

np.save('names.npy', np.array(x_data_name))
np.save('ids.npy', np.array(x_data_id))
# Labels must follow the weights1..weights4 order used to build `modelos`:
# ElasticFace, ms1mv3 ArcFace, glint360k CosFace, Partial FC glint360k.
output_names = ['embd_elasticface_r100', 'embd_arcface_r100', 'embd_cosface_r100', 'embd_partialFC_r100']
for out_name, chunks in zip(output_names, x_embd):
    np.save(out_name, np.concatenate(chunks) if chunks else np.array([]))

