#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
import os
from pathlib import Path

import numpy as np
import tqdm
from PIL import Image
import json


def convert(input, output):
    """Split one panoptic-map PNG into a semantic-label PNG and a size JSON.

    Two files are written into ``output``:

    * ``<input stem>_sem.png`` -- channel 0 of the decoded input, saved as a
      single-channel image. (Presumably channel 0 carries the semantic class
      id in the KITTI-STEP panoptic encoding -- TODO confirm against the
      dataset specification.)
    * ``<input stem>.json`` -- ``{"imgHeight": H, "imgWidth": W}`` taken from
      the decoded input's first two dimensions.

    Args:
        input: Path (or str) of the source image. It must decode to a uint8
            array with a channel axis, i.e. shape (H, W, C).
        output: Existing directory (Path or str) that receives both outputs.

    Raises:
        ValueError: if the decoded image is not uint8 or has no channel axis.
    """
    input = Path(input)
    output = Path(output)

    # Close the file handle deterministically. np.array copies the pixels,
    # so `img` stays valid after the image is closed.
    with Image.open(input) as im:
        img = np.array(im)

    # Explicit checks instead of `assert`, which `python -O` strips out.
    if img.dtype != np.uint8:
        raise ValueError(
            "{}: expected uint8 pixels, got {}".format(input, img.dtype))
    if img.ndim < 3:
        raise ValueError(
            "{}: expected an (H, W, C) image, got shape {}".format(
                input, img.shape))

    label = img[:, :, 0]
    output_labels = output / (input.stem + '_sem.png')
    Image.fromarray(label).save(output_labels)

    js = {"imgHeight": img.shape[0], "imgWidth": img.shape[1]}

    # Build the name from the stem (like the _sem.png output above). The old
    # str.replace('.png', '.json') left names with any other suffix (e.g.
    # '.PNG') unchanged, so the JSON overwrote the input image.
    output_js = output / (input.stem + '.json')
    with open(output_js, 'w') as fp:
        json.dump(js, fp)


def main():
    """Generate ``_sem.png`` / ``.json`` sidecars for KITTI-STEP panoptic maps.

    Scans ``$DETECTRON2_DATASETS/kitti_step/panoptic_maps/{train,val}/<seq>/``
    (``DETECTRON2_DATASETS`` defaults to ``datasets``, following detectron2's
    convention). Each source ``*.png`` is passed to ``convert``, which writes
    its outputs into the same sequence directory.
    """
    dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "kitti_step"
    for name in ["train", "val"]:
        annotation_dir = dataset_dir / "panoptic_maps" / name
        # Only sequence sub-directories hold panoptic maps; skip stray files.
        seq_dirs = sorted(d for d in annotation_dir.iterdir() if d.is_dir())
        for seq_dir in tqdm.tqdm(seq_dirs):
            # sorted() snapshots the listing before convert() adds files to
            # this same directory.
            for file in sorted(seq_dir.iterdir()):
                # Process only source PNGs. Skip "*_sem.png" and "*.json"
                # written by an earlier run so that re-running is safe.
                if file.suffix != ".png" or file.stem.endswith("_sem"):
                    continue
                convert(file, seq_dir)


if __name__ == "__main__":
    main()
