# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

import logging
import numpy as np
import os
import tempfile
import xml.etree.ElementTree as ET
from collections import OrderedDict, defaultdict
from functools import lru_cache
import torch

from detectron2.data import MetadataCatalog
from detectron2.utils import comm
from detectron2.utils.file_io import PathManager

from .evaluator import DatasetEvaluator


class BDD100KDetectionEvaluator(DatasetEvaluator):
    """
    Evaluate Pascal VOC–style AP for the BDD100K dataset.
    Uses the standard VOC devkit evaluation protocol.

    The dataset is read from a VOC-like directory layout:
    ``<meta.dirname>/Annotations/<image_id>.xml`` for ground truth and
    ``<meta.dirname>/ImageSets/<meta.split>.txt`` listing the image ids to
    evaluate. AP is reported at IoU thresholds 0.50, 0.55, ..., 0.95.
    """

    def __init__(self, dataset_name: str, use_07_metric: bool = True):
        """
        Args:
            dataset_name (str): registered dataset name, e.g. "bdd100k_train"
                or "bdd100k_test". Its metadata must provide ``dirname``,
                ``split`` and ``thing_classes``.
            use_07_metric (bool): if True (default), compute AP with the
                VOC2007 11-point interpolation; otherwise integrate the full
                precision/recall envelope.
        """
        self._dataset_name = dataset_name
        meta = MetadataCatalog.get(dataset_name)

        # Local copy of annotations for speed
        anno_dir = PathManager.get_local_path(
            os.path.join(meta.dirname, "Annotations")
        )
        self._anno_file_template = os.path.join(anno_dir, "{}.xml")
        # ImageSets/{split}.txt
        self._image_set_path = os.path.join(
            meta.dirname, "ImageSets", meta.split + ".txt"
        )
        self._class_names = meta.thing_classes
        self._use_07_metric = use_07_metric
        self._cpu_device = torch.device("cpu")
        self._logger = logging.getLogger(__name__)
        # Initialize the accumulator so process()/evaluate() work even if the
        # caller never calls reset() explicitly.
        self.reset()

    def reset(self):
        """Discard all accumulated predictions."""
        # Predictions per class_id -> list of formatted strings
        self._predictions = defaultdict(list)

    def process(self, inputs, outputs):
        """
        Called once per batch. Accumulates predictions.

        Each prediction is stored under its class id as a line
        ``"<image_id> <score> <xmin> <ymin> <xmax> <ymax>"``, the detection
        file format read by ``voc_eval``.
        """
        for inp, out in zip(inputs, outputs):
            image_id = inp["image_id"]
            instances = out["instances"].to(self._cpu_device)
            boxes = instances.pred_boxes.tensor.numpy()
            scores = instances.scores.tolist()
            classes = instances.pred_classes.tolist()
            for (xmin, ymin, xmax, ymax), score, cls in zip(boxes, scores, classes):
                # Convert to 1-based inclusive coords (VOC devkit convention):
                # only the top-left corner is shifted.
                xmin += 1.0
                ymin += 1.0
                # Single f-string: a second .format() pass would choke on any
                # braces contained in image_id.
                self._predictions[cls].append(
                    f"{image_id} {score:.3f} {xmin:.1f} {ymin:.1f} {xmax:.1f} {ymax:.1f}"
                )

    def evaluate(self):
        """
        Gather predictions from all ranks and compute VOC-style AP.

        Returns:
            On the main process, an OrderedDict with key "bbox" mapping to
            ``{"AP", "AP50", "AP75", "class-AP50": [...]}``. "AP" is the mean
            over IoU thresholds 0.50:0.05:0.95, and "class-AP50" holds
            per-class AP50 in ``thing_classes`` order. All values are
            percentages. Returns None on non-main processes.
        """
        # Gather across ranks
        all_preds = comm.gather(self._predictions, dst=0)
        if not comm.is_main_process():
            return None
        # Merge
        preds = defaultdict(list)
        for part in all_preds:
            for clsid, lines in part.items():
                preds[clsid].extend(lines)
        del all_preds

        self._logger.info(
            "Evaluating %s using %s metric.",
            self._dataset_name,
            "VOC2007 11-point" if self._use_07_metric else "full-envelope",
        )

        aps = defaultdict(list)
        with tempfile.TemporaryDirectory(prefix="voc_bdd100k_eval_") as tmpdir:
            res_template = os.path.join(tmpdir, "{}.txt")
            for cls_id, cls_name in enumerate(self._class_names):
                # Write prediction file for this class; voc_eval reads it back
                # via res_template.format(cls_name).
                with open(res_template.format(cls_name), "w") as f:
                    f.write("\n".join(preds.get(cls_id, [])))

                # Evaluate at multiple IoU thresholds: 0.50, 0.55, ..., 0.95
                for t in range(50, 100, 5):
                    _, _, ap = voc_eval(
                        res_template,
                        self._anno_file_template,
                        self._image_set_path,
                        cls_name,
                        ovthresh=t / 100.0,
                        use_07_metric=self._use_07_metric,
                    )
                    aps[t].append(ap * 100)

        # Aggregate results: per-threshold mean over classes, then over thresholds.
        mAP = {t: np.mean(v) for t, v in aps.items()}
        result = OrderedDict()
        result["bbox"] = {
            "AP": np.mean(list(mAP.values())),
            "AP50": mAP[50],
            "AP75": mAP[75],
            "class-AP50": aps[50],
        }
        return result


# --------------------------------------------------------
# Below: PASCAL VOC devkit-style evaluation implementation (parse_rec, voc_ap, voc_eval)
# --------------------------------------------------------

def _parse_coord(text):
    """Parse a box coordinate string: int when possible, float otherwise.

    VOC annotations use integer pixel coordinates, but converted BDD100K
    annotations may contain decimal values (e.g. "123.45"), which ``int()``
    alone rejects.
    """
    try:
        return int(text)
    except ValueError:
        return float(text)


@lru_cache(maxsize=None)
def parse_rec(filename):
    """Parse a PASCAL VOC–style XML file.

    Results are memoized per filename because voc_eval re-reads the same
    annotations for every class and IoU threshold. Callers must treat the
    returned list as read-only, since it is shared between calls.

    Returns:
        list of dicts, one per <object>, with keys "name", "pose" (None if
        absent), "truncated" (0 if absent), "difficult" (0 if absent) and
        "bbox" ([xmin, ymin, xmax, ymax]).
    """
    with PathManager.open(filename) as f:
        tree = ET.parse(f)
    objs = []
    for obj in tree.findall("object"):
        struct = {
            # "name" and the bounding box are required; a missing tag raises.
            "name": obj.find("name").text,
            # Optional metadata tags may be absent in converted annotations.
            "pose": obj.findtext("pose"),
            "truncated": int(obj.findtext("truncated", "0")),
            "difficult": int(obj.findtext("difficult", "0")),
            "bbox": [
                _parse_coord(obj.find("bndbox/" + tag).text)
                for tag in ("xmin", "ymin", "xmax", "ymax")
            ],
        }
        objs.append(struct)
    return objs


def voc_ap(rec, prec, use_07_metric=False):
    """Compute average precision from a precision/recall curve.

    Args:
        rec: numpy array of cumulative recall values.
        prec: numpy array of the corresponding cumulative precision values.
        use_07_metric: if True, use the VOC2007 11-point interpolation, i.e.
            the mean of the best precision reached at recall >= 0.0, 0.1,
            ..., 1.0. Otherwise, integrate the monotone precision envelope
            over recall.

    Returns:
        The AP as a scalar.
    """
    if not use_07_metric:
        # Pad so the curve spans recall 0 to 1.
        recall_pts = np.concatenate(([0.0], rec, [1.0]))
        precision_pts = np.concatenate(([0.0], prec, [0.0]))
        # Right-to-left running max: precision becomes non-increasing.
        for k in range(precision_pts.size - 2, -1, -1):
            precision_pts[k] = max(precision_pts[k], precision_pts[k + 1])
        # Sum rectangle areas only where recall actually changes.
        steps = np.where(recall_pts[1:] != recall_pts[:-1])[0]
        widths = recall_pts[steps + 1] - recall_pts[steps]
        return np.sum(widths * precision_pts[steps + 1])

    total = 0.0
    for threshold in np.arange(0.0, 1.1, 0.1):
        reached = rec >= threshold
        best = np.max(prec[reached]) if np.any(reached) else 0.0
        total += best / 11.0
    return total


def voc_eval(
    detpath, annopath, imagesetfile, classname,
    ovthresh=0.5, use_07_metric=False
):
    """
    Top-level PASCAL VOC evaluation for a single class.

    Args:
        detpath: template for the detection results file. The file named by
            ``detpath.format(classname)`` holds one detection per line:
            ``<image_id> <score> <xmin> <ymin> <xmax> <ymax>``.
        annopath: template for annotation XML files (``annopath.format(img_id)``).
        imagesetfile: text file listing the image ids to evaluate, one per
            line. Blank lines are ignored.
        classname: class name, compared against each object's "name".
        ovthresh: IoU threshold. A detection can match a GT box only if its
            IoU is strictly greater than this value.
        use_07_metric: passed through to ``voc_ap``.

    Returns:
        (rec, prec, ap): cumulative recall and precision over detections
        sorted by descending score, and the scalar AP. If there are no
        detections, returns empty arrays and an AP of 0.0.

    Raises:
        KeyError: if a detection refers to an image id not listed in
            ``imagesetfile``.
    """
    # Load image IDs. Skip blank lines: an empty id would otherwise be looked
    # up as the nonexistent annotation file annopath.format("").
    with PathManager.open(imagesetfile) as f:
        image_ids = [x for x in (line.strip() for line in f) if x]
    # Load ground truth (parse_rec memoizes per file).
    recs = {img: parse_rec(annopath.format(img)) for img in image_ids}

    # Extract GT for this class. "difficult" objects do not count as
    # positives, and detections matched to them are neither TP nor FP.
    class_recs = {}
    npos = 0
    for img in image_ids:
        R = [obj for obj in recs[img] if obj["name"] == classname]
        bbox = np.array([o["bbox"] for o in R])
        difficult = np.array([o["difficult"] for o in R], dtype=bool)
        det = [False] * len(R)  # whether each GT box has already been matched
        npos += np.sum(~difficult)
        class_recs[img] = {"bbox": bbox, "difficult": difficult, "det": det}

    # Read detections
    with open(detpath.format(classname)) as f:
        det_lines = [line.strip().split() for line in f if line.strip()]
    if not det_lines:
        return np.array([]), np.array([]), 0.0

    image_ids_det = [parts[0] for parts in det_lines]
    scores = np.array([float(parts[1]) for parts in det_lines])
    BB = np.array([[float(x) for x in parts[2:]] for parts in det_lines])

    # Sort by score descending
    idx = np.argsort(-scores)
    BB = BB[idx]
    image_ids_det = [image_ids_det[i] for i in idx]

    num_dets = len(image_ids_det)
    tp = np.zeros(num_dets)
    fp = np.zeros(num_dets)

    # Greedily match each detection, highest score first, to its best-IoU GT box.
    for i, img in enumerate(image_ids_det):
        R = class_recs[img]
        bb = BB[i]
        ovmax = -np.inf
        BBGT = R["bbox"].astype(float)

        if BBGT.size > 0:
            # IoU with the VOC "+1" inclusive-pixel area convention.
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
            ih = np.maximum(iymax - iymin + 1.0, 0.0)
            inters = iw * ih
            uni = (
                (bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0)
                + (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0)
                - inters
            )
            overlaps = inters / uni
            ovmax = overlaps.max()
            jmax = overlaps.argmax()

        # jmax is only read when ovmax > ovthresh, which requires a GT box in
        # this image, so it is always bound for the current detection.
        if ovmax > ovthresh:
            if not R["difficult"][jmax]:
                if not R["det"][jmax]:
                    tp[i] = 1.0
                    R["det"][jmax] = True
                else:
                    # Duplicate detection of an already-matched GT box.
                    fp[i] = 1.0
        else:
            fp[i] = 1.0

    # Compute precision-recall
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos) if npos > 0 else np.zeros_like(tp)
    # eps guards against division by zero.
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)

    ap = voc_ap(rec, prec, use_07_metric)
    return rec, prec, ap