
from collections import defaultdict
import json
import numpy as np
import os
import pandas as pd
from PIL import Image
import sys

from Config import get_source_dir, get_data_dir, id_from_path

sys.path.insert(0, '../Common')
from DataUtils import get_transform

if __name__ == '__main__':
    # Script entry point. Rebuilds the directory returned by get_data_dir()
    # from the raw files under get_source_dir():
    #   <base>/maps.json               - [name2index, index2name] attribute maps
    #   <base>/<split>/images/*.jpg    - reshaped copy of every image in the split
    #   <base>/<split>/images.json     - image id -> {'file': path, 'label': 0/1 list}
    #   <base>/<split>/name2ids.json   - attribute name -> ids of images having it
    # WARNING: any existing content of the data directory is deleted first.

    # Local import: only needed when the file is run as a script.
    import shutil

    # Helper function to load and reshape an image
    t = get_transform(mode = 'reshape')
    def reshape_image(filename):
        # Open inside a context manager so the source file handle is released
        # as soon as the converted copy has been produced.
        # NOTE(review): assumes `t` returns an object with a .save(path)
        # method (presumably a PIL image) - confirm against DataUtils.
        with Image.open(filename) as img:
            return t(img.convert('RGB'))

    # Setup: start from an empty output directory.
    # Uses shutil/os instead of shell commands so that paths containing
    # spaces or shell metacharacters are handled correctly and failures
    # raise instead of being silently ignored.
    source_dir = get_source_dir()
    base_dir = get_data_dir()
    if os.path.isdir(base_dir) and not os.path.islink(base_dir):
        shutil.rmtree(base_dir)
    elif os.path.lexists(base_dir):
        # A plain file or a symlink is in the way: remove just that entry.
        os.remove(base_dir)
    os.makedirs(base_dir)

    # Get the split for each image
    # Column 0 is the image path, column 1 is an integer index into
    # split_names (0 -> 'train', 1 -> 'val', 2 -> 'test').
    partition = pd.read_csv('{}/list_eval_partition.csv'.format(source_dir)).to_numpy()

    imgs = {'train': [], 'val': [], 'test': []}
    split_names = list(imgs)

    for path, split_index in zip(partition[:, 0], partition[:, 1]):
        imgs[split_names[split_index]].append(id_from_path(path))

    # Load the label info
    # The first column identifies the image; every other column is one attribute.
    attrs = pd.read_csv('{}/list_attr_celeba.csv'.format(source_dir))
    names = [x.replace('_', '+') for x in list(attrs.columns)[1:]]

    # Create the class maps
    name2index = {v: i for i, v in enumerate(names)}
    index2name = list(name2index)

    with open('{}/maps.json'.format(base_dir), 'w') as f:
        json.dump([name2index, index2name], f)

    # Get the labels for each image
    id2labels = {}
    name2ids = defaultdict(list)
    for row in attrs.to_numpy():
        img_id = id_from_path(row[0])
        # An attribute is present only when its value is exactly 1; any other
        # value is stored as 0.
        label = 1 * (row[1:] == 1)

        for index in np.flatnonzero(label):
            name2ids[index2name[index]].append(img_id)

        # Convert to plain ints so the labels are JSON serializable.
        id2labels[img_id] = [int(v) for v in label]

    # Process the splits
    for mode in ['test', 'val', 'train']:
        mode_dir = '{}/{}'.format(base_dir, mode)
        image_dir = '{}/images'.format(mode_dir)
        # Creates mode_dir as well, since it is the parent of image_dir.
        os.makedirs(image_dir, exist_ok=True)

        ids_mode = imgs[mode]

        # Reshape the images and get their labels
        images = {}
        for img_id in ids_mode:
            img_path = '{}/{}.jpg'.format(image_dir, img_id)

            img_data = reshape_image('{}/img_align_celeba/img_align_celeba/{}.jpg'.format(source_dir, img_id))
            img_data.save(img_path)

            images[img_id] = {'file': img_path, 'label': id2labels[img_id]}

        with open('{}/images.json'.format(mode_dir), 'w') as f:
            json.dump(images, f)

        # Create a map from each attribute to images with that attribute
        # Sorted, de-duplicated intersection of the attribute's ids with this
        # split's ids. Built from native Python values (not NumPy scalars) so
        # json.dump accepts the result; the split's id set is built once
        # rather than once per attribute.
        ids_mode_set = set(ids_mode)
        name2ids_mode = {}
        for name in names:
            name2ids_mode[name] = sorted(ids_mode_set.intersection(name2ids[name]))
        with open('{}/name2ids.json'.format(mode_dir), 'w') as f:
            json.dump(name2ids_mode, f)
            