import json
import os
import copy
from pathlib import Path

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.data.datasets.coco import load_sem_seg

SCANNETPP_DIR = Path("./data/scannet++/")
# Category list shared by all splits. Each entry is a dict with at least "id"
# and "name" keys, and the list is indexed by integer label.
# NOTE(review): this path is relative to the current working directory, not to
# the ``root`` passed to ``register_scannetpp`` — confirm that is intended.
# The file is read inside a ``with`` block so the handle is closed after
# parsing. ``CLASSES_FILE`` stays bound (to the now-closed file object) only so
# that any code importing the name keeps working.
with open(SCANNETPP_DIR / "semantic_classes.json", encoding="utf-8") as CLASSES_FILE:
    SCANNET_CATEGORIES = json.load(CLASSES_FILE)


def _get_categories_by_split(root, split_name):
    """Return the ``SCANNET_CATEGORIES`` entries used by one dataset split.

    The ``"val"`` split uses a fixed prefix of ``SCANNET_CATEGORIES``. Any other
    split reads ``<root>/semantic_classes_<split_name>.txt``. That file holds one
    integer index into ``SCANNET_CATEGORIES`` per line. The referenced entries
    are returned in file order, and blank lines are ignored.

    Args:
        root: Directory containing the per-split ``semantic_classes_*.txt`` files.
        split_name: Split name, e.g. ``"train"`` or ``"val"``.

    Returns:
        A list of category entries taken from ``SCANNET_CATEGORIES``.

    Raises:
        FileNotFoundError: If the index file for a non-val split is missing.
        ValueError: If a non-blank line is not an integer.
        IndexError: If an index is out of range for ``SCANNET_CATEGORIES``.
    """
    if split_name == "val":
        # NOTE(review): 1634 is a hard-coded class count for the val split;
        # confirm it still matches semantic_classes.json.
        return SCANNET_CATEGORIES[:1634]

    index_file = os.path.join(root, f"semantic_classes_{split_name}.txt")
    with open(index_file, encoding="utf-8") as f:
        # int() tolerates surrounding whitespace. Skipping blank lines avoids
        # crashing on int("") when the file ends with an empty line.
        label_ids = [int(line) for line in f if line.strip()]

    return [SCANNET_CATEGORIES[label_id] for label_id in label_ids]

def _get_stuff_colors():
    path = "./ov-seg-clip/open_vocab_seg/data/datasets/scannetpp_stuff_colors.txt"
    stuff_colors = []
    with open(path) as f:
        lines = f.readlines()
        for line in lines:
            colors = line.strip().split(' ')
            colors = [float(c) for c in colors]
            stuff_colors.append(colors)
    return stuff_colors


def _get_scannetpp_meta(categories):
    """Build detectron2 metadata fields for a list of ScanNet++ categories.

    Args:
        categories: Sequence of dicts, each with at least ``"id"`` and
            ``"name"`` keys.

    Returns:
        A dict with four keys:
            - ``stuff_ids``: category ids, in input order.
            - ``stuff_dataset_id_to_contiguous_id``: maps each dataset id to its
              position in ``categories``, i.e. a contiguous id in ``[0, N)``.
            - ``stuff_classes``: category names, in input order.
            - ``stuff_colors``: maps each class *name* to a palette entry from
              ``_get_stuff_colors()``, chosen by ``id % len(palette)``.

    Raises:
        ValueError: If ``categories`` is non-empty but the palette is empty.
    """
    # Every category is exposed under the "stuff_*" keys. Keeping detectron2's
    # standard naming lets its existing visualization code be reused.
    stuff_classes = [cat["name"] for cat in categories]

    # Dataset category ids are not necessarily contiguous, so two id spaces
    # are kept:
    #   - original (dataset) ids, mainly used for evaluation;
    #   - contiguous ids in [0, #classes), used to train the softmax classifier.
    stuff_ids = [cat["id"] for cat in categories]
    stuff_dataset_id_to_contiguous_id = {
        dataset_id: contiguous_id for contiguous_id, dataset_id in enumerate(stuff_ids)
    }

    palette = _get_stuff_colors()
    if stuff_ids and not palette:
        raise ValueError("stuff color palette is empty; cannot assign class colors")
    # NOTE(review): keyed by class name rather than contiguous id, so categories
    # that share a name collapse to one entry (the last one wins). detectron2's
    # stock Visualizer indexes stuff_colors by contiguous id; confirm that the
    # consumer here expects a name-keyed dict.
    stuff_colors = {
        class_name: palette[cat_id % len(palette)]
        for cat_id, class_name in zip(stuff_ids, stuff_classes)
    }

    return {
        "stuff_ids": stuff_ids,
        "stuff_dataset_id_to_contiguous_id": stuff_dataset_id_to_contiguous_id,
        "stuff_classes": stuff_classes,
        "stuff_colors": stuff_colors,
    }


def register_scannetpp(root):
    """Register the ScanNet++ train and val semantic-segmentation splits.

    For each split this does two things:
        - registers ``scannetpp_sem_seg_<split>`` in ``DatasetCatalog``, as a
          lazy ``load_sem_seg`` loader over the split's iPhone RGB images
          (``.jpg``) and rendered semantic-id masks (``.png``);
        - fills the matching ``MetadataCatalog`` entry with roots, evaluator
          type, ignore label and the category metadata.

    Args:
        root: Base datasets directory. The data is expected under
            ``<root>/scannet++/<split>/iphone/{rgb,render_semantic_id}``.
    """
    dataset_root = os.path.join(root, "scannet++")
    # Earlier versions used a single scene per split (training set 0cf2e9402d,
    # test 290ef3f2c9). Each split is now a group of scenes listed in
    # scenes_train / scenes_val. Run datasets/prepare_scannetpp.py to generate
    # the training set.
    for split in ("train", "val"):
        iphone_dir = os.path.join(dataset_root, split, "iphone")
        image_dir = os.path.join(iphone_dir, "rgb")
        gt_dir = os.path.join(iphone_dir, "render_semantic_id")

        split_meta = _get_scannetpp_meta(_get_categories_by_split(dataset_root, split))
        dataset_name = f"scannetpp_sem_seg_{split}"

        # The directories are bound as lambda defaults so that each split's
        # loader keeps its own paths, not the last iteration's values.
        DatasetCatalog.register(
            dataset_name,
            lambda images=image_dir, masks=gt_dir: load_sem_seg(
                masks, images, gt_ext="png", image_ext="jpg"
            ),
        )
        # NOTE(review): 65535 presumably marks unlabeled pixels in the 16-bit
        # masks; confirm against the mask renderer.
        MetadataCatalog.get(dataset_name).set(
            image_root=image_dir,
            sem_seg_root=gt_dir,
            evaluator_type="sem_seg",
            ignore_label=65535,
            **split_meta,
        )

# Register the ScanNet++ splits as an import-time side effect. The base
# directory comes from the DETECTRON2_DATASETS environment variable and
# defaults to "datasets".
_root = os.environ.get("DETECTRON2_DATASETS", "datasets")
register_scannetpp(_root)