#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.
import os
from pathlib import Path

import numpy as np
import tqdm
from PIL import Image
import json


def convert(img_path, label_path, output_image_dir, output_label_dir, idx):
    """Copy one image/label pair into the output layout under a sequential name.

    Writes ``<idx:06d>.png`` (the image) to ``output_image_dir``, and
    ``<idx:06d>_sem.png`` (the label map) plus a ``<idx:06d>.json`` sidecar
    containing ``{"imgHeight": H, "imgWidth": W}`` to ``output_label_dir``.

    Args:
        img_path (Path): source image. Its stem must equal ``label_path``'s
            stem and end in ``_<idx + 1:06d>``. Source files are numbered
            from 1, outputs from 0.
        label_path (Path): source label bitmap. It must decode to a 2-D uint8
            array whose shape equals the image's (height, width).
        output_image_dir (Path): existing directory receiving the image PNG.
        output_label_dir (Path): existing directory receiving the label PNG
            and the JSON sidecar.
        idx (int): 0-based output index.

    Raises:
        AssertionError: if the stems, numbering, dtypes or shapes disagree.
    """
    assert img_path.stem == label_path.stem, "{} != {}".format(img_path, label_path)
    img_root = format(idx, "06d")
    # Source files are 1-based while outputs are 0-based. This check also
    # catches pairs that arrive out of order.
    assert img_path.stem.split('_')[-1] == format(idx + 1, "06d"), str(idx)

    # Decode inside context managers so the source file handles are closed
    # immediately. np.array copies the pixels, so they outlive the handle.
    with Image.open(label_path) as label_img:
        labels = np.array(label_img)
    assert labels.dtype == np.uint8, label_path

    with Image.open(img_path) as src_img:
        image = np.array(src_img)
    assert image.dtype == np.uint8, img_path

    assert image.shape[:2] == labels.shape, "{} != {}".format(image.shape, labels.shape)

    output_image = output_image_dir / (img_root + '.png')
    Image.fromarray(image).save(output_image)

    output_labels = output_label_dir / (img_root + '_sem.png')
    Image.fromarray(labels).save(output_labels)

    # Sidecar metadata: label-map size (shape entries are plain Python ints,
    # so they serialize directly).
    js = {"imgHeight": labels.shape[0],
          "imgWidth": labels.shape[1]}

    output_js = output_label_dir / (img_root + '.json')
    with open(output_js, 'w') as fp:
        json.dump(js, fp)


if __name__ == "__main__":
    # os.getenv returns None when the variable is unset, and Path(None) raises
    # TypeError. Fall back to a relative "datasets" directory instead.
    dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets")) / "ttt_custom"
    for name in ["berkeley", "paris", "house"]:
        annotation_dir = dataset_dir / "visualizations" / (name + "_sem_seg_big_bitmasks")
        image_dir = dataset_dir / "original_images" / name

        # iterdir() yields entries in arbitrary filesystem order, so sort both
        # listings to pair files by name. JSON files in the annotation
        # directory are removed *before* pairing. Skipping them inside the
        # zipped loop would also discard the image zipped with them and shift
        # every later pair.
        annotations = sorted(p for p in annotation_dir.iterdir() if p.suffix != '.json')
        images = sorted(image_dir.iterdir())

        assert len(annotations) == len(images), "{}: {} annotations vs {} images".format(
            name, len(annotations), len(images))

        output_label_dir = dataset_dir / "labels_swin_l" / name
        output_label_dir.mkdir(parents=True, exist_ok=True)
        output_image_dir = dataset_dir / "images" / name
        output_image_dir.mkdir(parents=True, exist_ok=True)

        pairs = zip(images, annotations)
        for idx, (img_path, label_path) in enumerate(tqdm.tqdm(pairs, total=len(images), desc=name)):
            convert(img_path, label_path, output_image_dir, output_label_dir, idx)
