import numpy as np
import os
import xml.etree.ElementTree as ET
from typing import List, Tuple, Union

from detectron2.data import DatasetCatalog, MetadataCatalog
from detectron2.structures import BoxMode
from detectron2.utils.file_io import PathManager

# Names exported by `from <this module> import *`.
__all__ = ["load_visdrone_instances", "register_visdrone"]

# Default object classes kept by load_visdrone_instances: any <object> whose
# <name> is not listed here is skipped. A class's position in this tuple is
# used as its `category_id`, and register_visdrone stores the same names as
# the dataset's `thing_classes` metadata.
CLASS_NAMES = ("person", "bicycle", "car", "bus", "motor")

def load_visdrone_instances(
    dirname: str,
    split: str,
    class_names: Union[List[str], Tuple[str, ...]] = CLASS_NAMES
) -> List[dict]:
    """
    Load VOC-style XML annotations of VisDrone-DET as Detectron2 dataset dicts.

    Args:
        dirname: Dataset root folder. It must directly contain:
            Annotations/{id}.xml
            JPEGImages/{id}.jpg
            ImageSets/Main/{split}.txt
        split: Name of the image-list file under ImageSets/Main, without the
            ".txt" extension (e.g. 'train' or 'val').
        class_names: Object names to keep. Objects with any other name are
            skipped, and the position of a name in this sequence becomes the
            annotation's `category_id`.

    Returns:
        One dict per id listed in the split file, with the keys `file_name`,
        `image_id`, `height`, `width` and `annotations`. Each annotation has
        `category_id`, `bbox` (floats, [xmin, ymin, xmax, ymax]) and
        `bbox_mode` (BoxMode.XYXY_ABS). Image files are not opened here; only
        their paths are recorded.
    """
    list_file = os.path.join(dirname, "ImageSets", "Main", f"{split}.txt")
    with PathManager.open(list_file) as f:
        # np.loadtxt collapses a single-entry file to a 0-d array, which
        # cannot be iterated; atleast_1d keeps one-image splits working.
        fileids = np.atleast_1d(np.loadtxt(f, dtype=str))

    # Build the name -> id lookup once instead of scanning class_names for
    # every object. setdefault keeps the first occurrence, like list.index().
    class_to_id = {}
    for idx, cls_name in enumerate(class_names):
        class_to_id.setdefault(cls_name, idx)

    dicts = []
    ann_dir = PathManager.get_local_path(os.path.join(dirname, "Annotations"))
    img_dir = os.path.join(dirname, "JPEGImages")

    for fileid in fileids:
        # Use a plain str rather than numpy.str_ for paths and image_id.
        fileid = str(fileid)
        anno_file = os.path.join(ann_dir, fileid + ".xml")
        jpeg_file = os.path.join(img_dir, fileid + ".jpg")

        with PathManager.open(anno_file) as f:
            tree = ET.parse(f)
        root = tree.getroot()

        record = {
            "file_name": jpeg_file,
            "image_id": fileid,
            "height": int(root.find("./size/height").text),
            "width": int(root.find("./size/width").text),
        }

        annos = []
        for obj in root.findall("object"):
            cls = obj.find("name").text
            if cls not in class_to_id:
                continue
            bbox = obj.find("bndbox")
            coords = [
                float(bbox.find(tag).text)
                for tag in ("xmin", "ymin", "xmax", "ymax")
            ]
            # Shift only the top-left corner by one pixel; xmax/ymax are left
            # untouched. NOTE(review): this assumes the XML stores 1-based,
            # inclusive pixel indices (Pascal VOC convention) -- confirm for
            # the tool that produced these annotation files.
            coords[0] -= 1.0
            coords[1] -= 1.0
            annos.append({
                "category_id": class_to_id[cls],
                "bbox": coords,
                "bbox_mode": BoxMode.XYXY_ABS,
            })

        record["annotations"] = annos
        dicts.append(record)

    return dicts

def register_visdrone(name: str, dirname: str, split: str):
    """
    Make one VisDrone split available to Detectron2 under `name`.

    The dataset is registered lazily: annotations are only read when the
    registered loader is called.

    name: dataset key, e.g. 'visdrone_train'
    dirname: path to the lowercase 'visdrone' root folder
    split: 'train' or 'val'
    """
    def _load_split():
        # Deferred call; uses the loader's default class list.
        return load_visdrone_instances(dirname, split)

    DatasetCatalog.register(name, _load_split)

    # Record the class names and data location alongside the dataset.
    metadata = MetadataCatalog.get(name)
    metadata.set(thing_classes=list(CLASS_NAMES), dirname=dirname, split=split)
