import os
import torch
from torchvision import models, transforms
from tqdm import tqdm
from PIL import Image
import numpy as np

# Most recent forward-hook outputs, keyed by the name given to get_activation().
activation = {}


def get_activation(name):
    """Return a forward hook that records a module's output in ``activation``.

    The returned callable has the ``(module, input, output)`` signature used for
    forward hooks. Each time it runs, it stores a detached copy of ``output``
    (no autograd history) under ``activation[name]``, overwriting any previous
    value for that name.
    """
    def _store_output(module, inputs, output):
        activation[name] = output.detach()

    return _store_output

def load_image_batch(img_path_list, transform=None):
    """Load images from disk and stack them into a single batch tensor.

    Each path is opened with PIL, converted to RGB, and passed through
    ``transform``. The per-image tensors are stacked along a new leading
    dimension in the order given.

    Args:
        img_path_list: Paths of the image files to load, in batch order.
        transform: Callable mapping an RGB PIL image to a tensor. Defaults to
            the module-level ``preprocess`` pipeline. It is looked up at call
            time, so it may be defined after this function.

    Returns:
        A tensor whose first dimension has length ``len(img_path_list)``.

    Raises:
        Whatever ``torch.stack`` raises when ``img_path_list`` is empty
        (it cannot stack zero tensors).
    """
    if transform is None:
        transform = preprocess

    img_batch = []
    for img_path in img_path_list:
        # The context manager closes the file handle deterministically instead
        # of leaving it to PIL's lazy-loading internals / garbage collection.
        with Image.open(img_path) as img:
            img_batch.append(transform(img.convert('RGB')))

    return torch.stack(img_batch, dim=0)

def save_feat_batch(feat_arr, feat_path_list):
    """Save each feature row to its matching path with ``np.save``.

    Rows of ``feat_arr`` (iterated along its first axis) are paired in order
    with the entries of ``feat_path_list``. Pairing stops at whichever of the
    two runs out first.
    """
    for destination, vector in zip(feat_path_list, feat_arr):
        np.save(file=destination, arr=vector)

img_directory = 'img_align_celeba/img_align_celeba'
feature_directory = 'img_align_celeba/feat_align_celeba'

batch_size = 32

feat_path_batch = []
img_path_batch = []

# Pretrained ResNet-50 in eval mode on the GPU. The forward hook stores the
# avgpool output in activation['avgpool'] on every forward pass.
# NOTE(review): `pretrained=True` is deprecated in newer torchvision releases in
# favour of the `weights=` argument; kept for compatibility with older versions.
resnet = models.resnet50(pretrained=True)
resnet.cuda()
resnet.eval()
resnet.avgpool.register_forward_hook(get_activation('avgpool'))

# Resize, 224x224 center crop, tensor conversion, and per-channel
# normalization with the usual ImageNet mean/std. Used by load_image_batch().
preprocess = transforms.Compose([
            transforms.Resize(256),
            transforms.CenterCrop(224),
            transforms.ToTensor(),
            transforms.Normalize(
                mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225])
            ])

# np.save does not create missing parent directories.
os.makedirs(feature_directory, exist_ok=True)


def _extract_and_save(img_paths, feat_paths):
    """Run one batch of images through ResNet-50 and save each avgpool feature."""
    batch_img_tensor = load_image_batch(img_paths)
    # Inference only: no_grad avoids building an autograd graph for the batch.
    with torch.no_grad():
        resnet(batch_img_tensor.cuda())
    # pop() reads the hook's output and clears it, so a stale value can never
    # be reused for a later batch.
    pooled = activation.pop('avgpool')
    # Keep the batch dimension and flatten the rest; a bare .squeeze() would
    # also drop the batch dimension for a single-image batch.
    # Assumes pooled is (N, C, 1, 1) -> (N, C) -- TODO confirm if the model changes.
    feat_arr = torch.flatten(pooled, start_dim=1).cpu().numpy()
    save_feat_batch(feat_arr, feat_paths)


# Sorted so the processing order (and batch composition) is deterministic.
for file in tqdm(sorted(os.listdir(img_directory))):
    img_path = os.path.join(img_directory, file)
    # Explicit check instead of assert, which is stripped under `python -O`.
    if not (os.path.isfile(img_path) and file.endswith('.jpg')):
        raise ValueError(f'Unexpected entry in {img_directory!r}: {file!r}')
    feat_path = os.path.join(feature_directory, file[:-len('.jpg')] + '.npy')

    img_path_batch.append(img_path)
    feat_path_batch.append(feat_path)

    if len(img_path_batch) == batch_size:
        _extract_and_save(img_path_batch, feat_path_batch)
        img_path_batch = []
        feat_path_batch = []

# Flush the final, partially filled batch. Without this, the last
# len(files) % batch_size images would never get features.
if img_path_batch:
    _extract_and_save(img_path_batch, feat_path_batch)
    img_path_batch = []
    feat_path_batch = []
