
import io
import json
import numpy as np
import os
from PIL import Image
from pycocotools.coco import COCO
from sklearn.model_selection import train_test_split
import sys

from Config import get_data_dir, id_from_path

sys.path.insert(0, '../Common')
from DataUtils import get_transform

class COCOWrapper():
    """Small convenience wrapper around a pycocotools ``COCO`` annotation index.

    Attributes:
        coco: the ``COCO`` object built from the split's instances file.
        cats: the category records returned by ``coco.loadCats`` for every
            category id known to ``coco``.
        root, mode, year: the constructor arguments, kept so that paths for
            this split can be rebuilt later.
    """

    def __init__(self, root = '/home/gregory/Datasets/COCO', mode = 'val', year = '2017'):
        """Load ``{root}/annotations/instances_{mode}{year}.json``.

        Anything ``COCO(...)`` writes to stdout while loading is discarded.

        Args:
            root: directory that contains the ``annotations`` folder.
            mode: split name used in the annotation file name (e.g. 'train', 'val').
            year: dataset year used in the annotation file name.
        """
        stdout_orig = sys.stdout
        sys.stdout = io.StringIO()
        try:
            coco = COCO('{}/annotations/instances_{}{}.json'.format(root, mode, year))
        finally:
            # Always restore stdout: if loading fails (e.g. missing file) the
            # process must not be left with its output silently swallowed.
            sys.stdout = stdout_orig
        cats = coco.loadCats(coco.getCatIds())
        self.coco = coco
        self.cats = cats
        self.root = root
        self.mode = mode
        self.year = year

    def get_source_dir(self):
        """Return the directory path for this split: ``{root}/{mode}{year}``."""
        return '{}/{}{}'.format(self.root, self.mode, self.year)

    def get_images_with_cats(self, cats):
        """Return the image records selected by the given category names.

        Args:
            cats: list of category names in which '+' stands for a space
                (e.g. 'traffic+light'), or None.

        Returns:
            The result of ``coco.loadImgs`` for the ids selected by
            ``coco.getImgIds(catIds=coco.getCatIds(catNms=cats))``.
        """
        if cats is not None:
            # Names are stored elsewhere with '+' instead of ' '; undo that
            # before querying the annotation index.
            cats = [x.replace('+', ' ') for x in cats]
        # NOTE(review): the script below passes cats=None to obtain every image
        # of the split; that relies on how pycocotools treats an unmatched
        # name filter -- confirm against the installed version.
        coco = self.coco
        cat_ids = coco.getCatIds(catNms = cats)
        img_ids = coco.getImgIds(catIds = cat_ids)
        return coco.loadImgs(img_ids)

    def get_annotations(self, img_obj):
        """Return the annotation records for ``img_obj`` (looked up by its 'id' key)."""
        coco = self.coco
        ann_ids = coco.getAnnIds(img_obj['id'], iscrowd = None)
        return coco.loadAnns(ann_ids)

if __name__ == '__main__':
    # Script entry point: rebuilds the processed dataset under get_data_dir().
    # For each of the train/val/test splits it writes reshaped images,
    # images.json (image id -> file path and label vector) and name2ids.json
    # (category name -> ids of the split's images selected for that name).

    import shutil  # only needed by this script section

    # Length of the scratch label vector. It is indexed directly by the raw
    # 'category_id' of each annotation, so it is larger than the number of
    # categories and contains unused positions that are dropped afterwards.
    LABEL_VECTOR_SIZE = 91

    # Helper function to load and reshape an image
    t = get_transform(mode = 'reshape')
    def reshape_image(filename):
        # Close the source file deterministically once the RGB copy exists
        with Image.open(filename) as source:
            return t(source.convert('RGB'))

    # Setup: start from an empty output directory.
    # Done with os/shutil rather than shell commands so that paths containing
    # spaces or shell metacharacters work and failures raise instead of being
    # silently ignored.
    base_dir = get_data_dir()
    if os.path.islink(base_dir) or os.path.isfile(base_dir):
        os.remove(base_dir)
    elif os.path.isdir(base_dir):
        shutil.rmtree(base_dir)
    os.makedirs(base_dir)

    # Divide the dataset into train/val/test splits
    # Because COCO's test split is private, we do two things
    # 1) We use the val split for testing
    # 2) We divide the train split into training and validation splits for training models
    cocos = {}
    imgs = {}
    dirs = {}

    coco = COCOWrapper(mode = 'train')
    source_dir = coco.get_source_dir()
    imgs_tmp = coco.get_images_with_cats(None)

    # 'train' and 'val' share the same wrapper and source directory; only the
    # list of images differs.
    cocos['train'] = coco
    cocos['val'] = coco
    imgs['train'], imgs['val'] = train_test_split(imgs_tmp, test_size = 0.1, random_state = 0)
    dirs['train'] = source_dir
    dirs['val'] = source_dir

    coco = COCOWrapper(mode = 'val')
    source_dir = coco.get_source_dir()
    imgs_tmp = coco.get_images_with_cats(None)

    cocos['test'] = coco
    imgs['test'] = imgs_tmp
    dirs['test'] = source_dir

    # Create the class maps
    # (category names are stored with '+' in place of spaces)
    cats = coco.cats
    indices = [int(x['id']) for x in cats]
    names = [x['name'].replace(' ', '+') for x in cats]

    name2index = {name: i for i, name in enumerate(names)}
    index2name = list(name2index)

    with open('{}/maps.json'.format(base_dir), 'w') as f:
        json.dump([name2index, index2name], f)

    # From here on only the per-split wrappers in `cocos` are used
    coco = None

    # Process the splits
    for mode in ['test', 'val', 'train']:
        mode_dir = '{}/{}'.format(base_dir, mode)
        image_dir = '{}/images'.format(mode_dir)
        os.makedirs(image_dir)  # also creates mode_dir

        # Reshape the images and get their labels
        images = {}
        for img in imgs[mode]:
            img_id = id_from_path(img['file_name'])
            img_path = '{}/{}.jpg'.format(image_dir, img_id)

            img_data = reshape_image('{}/{}'.format(dirs[mode], img['file_name']))
            img_data.save(img_path)

            # Multi-hot label: mark every category that has an annotation
            label = np.zeros(LABEL_VECTOR_SIZE)
            for ann in cocos[mode].get_annotations(img):
                label[ann['category_id']] = 1.0
            # Keep only the positions of real categories, in the order of
            # `names`, and convert to plain Python floats for JSON
            label = label[indices].tolist()

            images[img_id] = {'file': img_path, 'label': label}

        with open('{}/images.json'.format(mode_dir), 'w') as f:
            json.dump(images, f)

        # Create a map from each object to images with that object
        ids_all = [id_from_path(img['file_name']) for img in imgs[mode]]

        name2ids = {}
        for name in names:
            tmp = [id_from_path(img['file_name']) for img in cocos[mode].get_images_with_cats([name])]
            # The wrapper may cover more images than this split (train and val
            # share one), so restrict to the split's own ids. tolist() turns
            # NumPy scalars into built-in types that json can serialize.
            name2ids[name] = np.intersect1d(ids_all, tmp).tolist()

        with open('{}/name2ids.json'.format(mode_dir), 'w') as f:
            json.dump(name2ids, f)
                        