import os
import numpy as np
import xml.etree.ElementTree as ET
from typing import List, Tuple, Union

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode
from detectron2.utils.file_io import PathManager

__all__ = [
    "load_bdd100k_instances",
    "register_bdd100k",
]

# Object categories kept from the BDD100K annotations. When used as the
# ``class_names`` list, a class's position in this tuple is the contiguous
# ``category_id`` assigned to its boxes (0 = "truck", ..., 6 = "bus").
CLASS_NAMES = (
    "truck",
    "car",
    "rider",
    "person",
    "motor",
    "bicycle",
    "bus",
)


def load_bdd100k_instances(
    dirname: str,
    split: str,
    class_names: Union[List[str], Tuple[str, ...]] = CLASS_NAMES,
):
    """
    Load VOC-style BDD100K XML annotations as Detectron2 dataset dicts.

    Args:
        dirname: The BDD100K root directory itself, laid out as::

            <dirname>/
            ├── Annotations/   one ``<image_id>.xml`` per image
            ├── JPEGImages/    one ``<image_id>.jpg`` per image
            └── ImageSets/
                ├── train.txt  image ids, one per line
                └── test.txt

        split: Name of the id list to read from ``ImageSets`` (e.g. 'train'
            or 'test').
        class_names: Categories to keep. Objects whose ``<name>`` is not in
            this sequence are dropped; a kept object's ``category_id`` is the
            index of its class in this sequence.

    Returns:
        list[dict]: One record per listed image id with keys ``file_name``,
        ``image_id``, ``height``, ``width`` and ``annotations``. Each
        annotation holds ``category_id``, ``bbox`` ([x0, y0, x1, y1]) and
        ``bbox_mode`` (``BoxMode.XYXY_ABS``).
    """
    # Read the image ids for this split. ``ndmin=1`` matters: for a split file
    # with a single id, numpy would otherwise squeeze the result to a 0-d array,
    # and iterating a 0-d array raises TypeError.
    list_file = os.path.join(dirname, "ImageSets", f"{split}.txt")
    with PathManager.open(list_file) as f:
        fileids = np.loadtxt(f, dtype=str, ndmin=1)

    dicts = []
    ann_dir = PathManager.get_local_path(os.path.join(dirname, "Annotations"))
    img_dir = os.path.join(dirname, "JPEGImages")

    for fileid in fileids:
        # Use a plain ``str`` (not ``numpy.str_``) for ids stored in the dicts.
        fileid = str(fileid)
        anno_file = os.path.join(ann_dir, fileid + ".xml")
        img_file = os.path.join(img_dir, fileid + ".jpg")

        # Parse the per-image XML annotation.
        with PathManager.open(anno_file) as f:
            tree = ET.parse(f)
        root = tree.getroot()

        record = {
            "file_name": img_file,
            "image_id": fileid,
            "height": int(root.findtext("./size/height")),
            "width": int(root.findtext("./size/width")),
        }

        annos = []
        for obj in root.findall("object"):
            cls = obj.findtext("name")
            if cls not in class_names:
                continue
            bbox = obj.find("bndbox")
            coords = [
                float(bbox.findtext(x)) for x in ("xmin", "ymin", "xmax", "ymax")
            ]
            # Treat the XML coordinates as 1-based (VOC convention): shift only
            # xmin/ymin to 0-based and keep xmax/ymax as-is.
            # NOTE(review): confirm the converted BDD100K XML really is 1-based.
            coords[0] -= 1.0
            coords[1] -= 1.0
            annos.append({
                "category_id": class_names.index(cls),
                "bbox": coords,
                "bbox_mode": BoxMode.XYXY_ABS,
            })

        record["annotations"] = annos
        dicts.append(record)

    return dicts


def register_bdd100k(
    name: str,
    dirname: str,
    split: str,
    class_names: Union[List[str], Tuple[str, ...]] = CLASS_NAMES,
):
    """
    Register one BDD100K split with Detectron2's DatasetCatalog and MetadataCatalog.

    Args:
        name: Dataset name to register under, e.g. 'bdd100k_train'.
        dirname: Path to the BDD100K root folder, i.e. the directory that
            directly contains ``Annotations/``, ``JPEGImages/`` and
            ``ImageSets/``.
        split: Id-list name under ``ImageSets`` to load, e.g. 'train' or 'test'.
        class_names: Categories to keep (defaults to ``CLASS_NAMES``). The same
            sequence is passed to the loader and stored as ``thing_classes``,
            so ``category_id`` values and metadata class names stay aligned.
    """
    # Snapshot the classes once. The loader runs lazily (whenever the catalog is
    # queried), so a later mutation of the caller's list must not make it
    # disagree with the ``thing_classes`` recorded below.
    class_names = tuple(class_names)

    DatasetCatalog.register(
        name,
        lambda: load_bdd100k_instances(dirname, split, class_names),
    )
    MetadataCatalog.get(name).set(
        thing_classes=list(class_names),
        dirname=dirname,
        split=split,
    )
