# Copyright (c) Facebook, Inc. and its affiliates.
# Copyright (c) Meta Platforms, Inc. All Rights Reserved
# Modified by Feng Liang from https://github.com/MendelXu/zsseg.baseline/blob/master/datasets/prepare_voc_sem_seg.py
# Modified by Heeseong Shin from https://github.com/facebookresearch/ov-seg/blob/main/datasets/prepare_voc_sem_seg.py

import os
import os.path as osp
from pathlib import Path
import tqdm

import numpy as np
from PIL import Image


# Mapping from PASCAL VOC label values to contiguous train IDs.
# Labels 1..20 are shifted down by one so train IDs start at 0.
# Label 0 and label 255 both map to 255, the same value convert_to_trainID
# uses for pixels whose label is not in the mapping (presumably an ignore
# index for training -- confirm against the training config).
clsID_to_trID = {0: 255}
clsID_to_trID.update({cls_id: cls_id - 1 for cls_id in range(1, 21)})
clsID_to_trID[255] = 255

# Variant that keeps label 0 as its own train ID (20) instead of mapping it to
# 255. It is a separate dict, so mutating it does not affect clsID_to_trID.
clsID_to_trID_bg = dict(clsID_to_trID)
clsID_to_trID_bg[0] = 20

def convert_to_trainID(
    maskpath, out_mask_dir, is_train, clsID_to_trID=clsID_to_trID, suffix=""
):
    """Remap a label PNG to train IDs and save it under ``out_mask_dir``.

    Every pixel whose value is a key of ``clsID_to_trID`` is replaced by the
    mapped value; every other pixel becomes 255. The result is written as a
    PNG with the same basename as ``maskpath`` into
    ``<out_mask_dir>/<"train" or "val"><suffix>/``. That directory must already
    exist; it is not created here.

    If no pixel differs from 255 after remapping (including an empty image),
    nothing is written.

    Args:
        maskpath: Path to the source label PNG.
        out_mask_dir: Root output directory (str or path-like).
        is_train: Selects the ``train`` (True) or ``val`` (False) subdirectory.
        clsID_to_trID: Mapping from source label value to output value.
        suffix: String appended to the split subdirectory name.
    """
    # Context manager so the underlying file handle is closed deterministically.
    with Image.open(maskpath) as img:
        mask = np.array(img)

    # Start from all-255 so labels missing from the mapping end up as 255.
    mask_copy = np.full_like(mask, 255, dtype=np.uint8)
    for clsID, trID in clsID_to_trID.items():
        # Test against the untouched source mask, so an already-remapped value
        # can never be matched again by a later source label.
        mask_copy[mask == clsID] = trID

    # Single pass instead of calling np.unique twice: skip masks that hold
    # nothing but 255.
    if not np.any(mask_copy != 255):
        return

    split = "train" if is_train else "val"
    seg_filename = osp.join(out_mask_dir, split + suffix, osp.basename(maskpath))
    Image.fromarray(mask_copy).save(seg_filename, "PNG")



if __name__ == "__main__":
    # Converts the VOC2012 segmentation split(s) listed below into two output
    # trees: one using clsID_to_trID and one using clsID_to_trID_bg.
    dataset_dir = Path(os.getenv("DETECTRON2_DATASETS", "datasets"))
    print('Caution: we only generate the validation set!')
    voc_path = dataset_dir / "VOCdevkit" / "VOC2012"
    out_mask_dir = voc_path / "annotations_detectron2"
    out_mask_dir_bg = voc_path / "annotations_detectron2_bg"
    for name in ["val"]:
        # convert_to_trainID writes into these directories but does not create them.
        os.makedirs((out_mask_dir / name), exist_ok=True)
        os.makedirs((out_mask_dir_bg / name), exist_ok=True)
        # dtype=str: the np.str alias was removed in NumPy 1.24.
        # ndmin=1: a list file with a single entry would otherwise load as a
        # 0-d array whose .tolist() is a bare str, iterated per character.
        image_ids = np.loadtxt(
            osp.join(voc_path, "ImageSets/Segmentation", name + ".txt"),
            dtype=str,
            ndmin=1,
        ).tolist()
        mask_paths = [
            osp.join(voc_path, "SegmentationClassAug", image_id + ".png")
            for image_id in image_ids
        ]
        # Keep the output subdirectory consistent with the split being processed.
        is_train = name == "train"
        for mask_path in tqdm.tqdm(mask_paths):
            convert_to_trainID(mask_path, out_mask_dir, is_train=is_train)
            convert_to_trainID(
                mask_path,
                out_mask_dir_bg,
                is_train=is_train,
                clsID_to_trID=clsID_to_trID_bg,
            )