# -*- coding: utf-8 -*-
# Copyright (c) Facebook, Inc. and its affiliates.

import logging
import numpy as np
import os
import itertools
import tempfile
import xml.etree.ElementTree as ET
from collections import OrderedDict, defaultdict
from functools import lru_cache
import torch

from detectron2.data import MetadataCatalog
from detectron2.utils import comm
from detectron2.utils.file_io import PathManager

from .evaluator import DatasetEvaluator

import pdb
from detectron2.structures import Boxes, pairwise_iou


class FoggyDetectionEvaluator(DatasetEvaluator):
    """
    Evaluate Pascal VOC style AP on a dataset whose annotations are stored in
    Pascal VOC format (one XML file per image under ``<dirname>/Annotations``
    and the image list in ``<dirname>/ImageSets/Main/<split>.txt``).
    It contains a synchronization, therefore has to be called from all ranks.

    Per-class AP is computed at IoU thresholds 0.50, 0.55, ..., 0.95.

    Note that the concept of AP can be implemented in different ways and may not
    produce identical results. This class mimics the implementation of the official
    Pascal VOC Matlab API, and should produce similar but not identical results to the
    official API.
    """

    def __init__(self, dataset_name):
        """
        Args:
            dataset_name (str): name of a registered dataset, e.g. "voc_2007_test".
                Its metadata must provide ``dirname``, ``split``,
                ``thing_classes`` and ``year`` (2007 or 2012).
        """
        self._dataset_name = dataset_name
        meta = MetadataCatalog.get(dataset_name)

        # Too many tiny files, download all to local for speed.
        annotation_dir_local = PathManager.get_local_path(
            os.path.join(meta.dirname, "Annotations/")
        )
        self._anno_file_template = os.path.join(annotation_dir_local, "{}.xml")
        self._image_set_path = os.path.join(meta.dirname, "ImageSets", "Main", meta.split + ".txt")
        self._class_names = meta.thing_classes
        assert meta.year in [2007, 2012], meta.year
        # Selects the AP definition in voc_eval: 11-point interpolation for 2007,
        # area under the precision envelope otherwise.
        self._is_2007 = meta.year == 2007
        self._cpu_device = torch.device("cpu")
        self._logger = logging.getLogger(__name__)

    def reset(self):
        # class id -> list of "image_id score xmin ymin xmax ymax" strings
        self._predictions = defaultdict(list)

    def process(self, inputs, outputs):
        """
        Record every predicted box as a text line
        "image_id score xmin ymin xmax ymax", grouped by predicted class id.
        """
        for input, output in zip(inputs, outputs):
            image_id = input["image_id"]
            instances = output["instances"].to(self._cpu_device)
            boxes = instances.pred_boxes.tensor.numpy()
            scores = instances.scores.tolist()
            classes = instances.pred_classes.tolist()
            for box, score, cls in zip(boxes, scores, classes):
                xmin, ymin, xmax, ymax = box
                # The inverse of data loading logic in `datasets/pascal_voc.py`
                xmin += 1
                ymin += 1
                self._predictions[cls].append(
                    f"{image_id} {score:.3f} {xmin:.1f} {ymin:.1f} {xmax:.1f} {ymax:.1f}"
                )

    def evaluate(self):
        """
        Gather predictions from all ranks and run the per-class VOC evaluation
        at IoU thresholds 0.50, 0.55, ..., 0.95.

        Returns:
            On the main process, an OrderedDict ``{"bbox": {...}}`` whose inner
            dict has "AP" (mean over the ten thresholds), "AP50" and "AP75"
            (each averaged over classes, in percent), plus "class-AP50" (the
            per-class AP50 list, in ``thing_classes`` order).
            ``None`` on every other rank.
        """
        all_predictions = comm.gather(self._predictions, dst=0)
        if not comm.is_main_process():
            return
        predictions = defaultdict(list)
        for predictions_per_rank in all_predictions:
            for clsid, lines in predictions_per_rank.items():
                predictions[clsid].extend(lines)
        del all_predictions

        with tempfile.TemporaryDirectory(prefix="pascal_voc_eval_") as dirname:
            res_file_template = os.path.join(dirname, "{}.txt")

            aps = defaultdict(list)  # iou threshold (percent) -> AP per class
            for cls_id, cls_name in enumerate(self._class_names):
                # A class without detections still gets an (empty) results file.
                lines = predictions.get(cls_id, [""])

                with open(res_file_template.format(cls_name), "w") as f:
                    f.write("\n".join(lines))

                for thresh in range(50, 100, 5):
                    rec, prec, ap = voc_eval(
                        res_file_template,
                        self._anno_file_template,
                        self._image_set_path,
                        cls_name,
                        ovthresh=thresh / 100.0,
                        use_07_metric=self._is_2007,
                    )
                    aps[thresh].append(ap * 100)

        mAP = {iou: np.mean(x) for iou, x in aps.items()}
        ret = OrderedDict()
        ret["bbox"] = {
            "AP": np.mean(list(mAP.values())),
            "AP50": mAP[50],
            "AP75": mAP[75],
            "class-AP50": aps[50],
        }
        return ret


# ---------------------------------------------------------------
#  A.  CLASS-ONLY  evaluator
# ---------------------------------------------------------------
class FoggyClsOnlyEvaluator(FoggyDetectionEvaluator):
    """
    AP with perfect localisation: measures class correctness only.

    Every prediction whose box overlaps some ground-truth box with IoU >= 0.5
    is replaced ("snapped") by that GT box, keeping its predicted class and
    score. Predictions that overlap no GT box well enough, and predictions on
    images without GT, are dropped. The regular per-class VOC evaluation is
    then run on the snapped boxes.
    """

    def process(self, inputs, outputs):
        """
        For each image:
          1. Load the GT boxes from the Pascal-VOC XML (same parser as voc_eval).
          2. Snap every prediction to its best-IoU GT box (IoU >= 0.5).

        The snapped box is written in raw XML coordinates, i.e. exactly the box
        voc_eval compares against, so the localisation is perfect.
        """
        for inp, out in zip(inputs, outputs):
            img_id = inp["image_id"]

            # ---------- load GT from XML ----------
            gt_objs = parse_rec(self._anno_file_template.format(img_id))
            if len(gt_objs) == 0:  # no GT to snap to, skip
                continue
            # Raw XML coordinates: the frame voc_eval uses for its GT boxes.
            gt_xml = torch.as_tensor(
                [obj["bbox"] for obj in gt_objs], dtype=torch.float32
            )  # (G,4)

            # ---------- predictions ----------
            pred = out["instances"].to(self._cpu_device)
            if len(pred) == 0:
                continue
            # Predicted boxes are in the data loader's frame, where xmin/ymin are
            # one less than in the XML (the base class adds 1 before writing).
            # Shift them into XML coordinates so the IoU match is like-for-like.
            p_boxes = pred.pred_boxes.tensor.clone()  # (P,4)
            p_boxes[:, :2] += 1
            p_scores = pred.scores.tolist()
            p_clsids = pred.pred_classes.tolist()

            # ---------- IoU match ----------
            ious = pairwise_iou(Boxes(p_boxes), Boxes(gt_xml)).numpy()
            best_gt = ious.argmax(1)
            best_iou = ious[np.arange(len(p_boxes)), best_gt]

            for p_idx, (score, cls_id) in enumerate(zip(p_scores, p_clsids)):
                if best_iou[p_idx] < 0.5:
                    continue  # localisation too poor
                # No +1 shift: the GT box is already in XML coordinates, so it
                # is written out unchanged and matches the GT exactly.
                x1, y1, x2, y2 = gt_xml[best_gt[p_idx]].tolist()
                self._predictions[cls_id].append(
                    f"{img_id} {score:.3f} "
                    f"{x1:.1f} {y1:.1f} {x2:.1f} {y2:.1f}"
                )

    def evaluate(self):
        """Run the base evaluation and re-key its result as ``cls_bbox``."""
        out = super().evaluate()
        if out is None:  # non-main ranks
            return
        return OrderedDict(cls_bbox=out["bbox"])  # unique key


# ---------------------------------------------------------------
#  B.  LOC-ONLY  evaluator
# ---------------------------------------------------------------
class FoggyLocOnlyEvaluator(FoggyDetectionEvaluator):
    """
    Class-agnostic AP50: measures localisation quality only.

    All categories are collapsed into one. Ground-truth boxes are loaded from
    the XML annotations as usual. Detections are processed in descending score
    order, and each one is a TP if its single best-overlapping GT box in the
    same image has IoU >= 0.5 and has not been matched yet, regardless of
    class. Otherwise it is an FP.

    NOTE(review): unlike ``voc_eval``, GT objects flagged "difficult" are not
    ignored here; they count as positives and can be matched.
    """

    def __init__(self, dataset_name):
        super().__init__(dataset_name)  # keep paths, logger, etc.
        self._class_names = ["object"]  # single dummy category (key 0)

    # ---------- prediction collection ----------
    def process(self, inputs, outputs):
        """Collect every detection under key 0, ignoring its predicted class."""
        for inp, out in zip(inputs, outputs):
            img_id = inp["image_id"]
            inst = out["instances"].to(self._cpu_device)

            boxes = inst.pred_boxes.tensor.numpy()
            scores = inst.scores.tolist()

            for box, score in zip(boxes, scores):
                x1, y1, x2, y2 = box
                # VOC 1-based shift (inverse of the data loader's conversion)
                x1 += 1
                y1 += 1
                self._predictions[0].append(
                    f"{img_id} {score:.3f} {x1:.1f} {y1:.1f} {x2:.1f} {y2:.1f}"
                )

    def _load_gt(self):
        """Return ``(gt_dict, npos)``: per-image GT boxes / match flags and total GT count."""
        with PathManager.open(self._image_set_path, "r") as f:
            img_list = [x.strip() for x in f.readlines()]

        gt_dict, npos = {}, 0
        for img in img_list:
            objs = parse_rec(self._anno_file_template.format(img))
            gt_dict[img] = {
                "bbox": np.array([o["bbox"] for o in objs]),
                "det": [False] * len(objs),
            }
            npos += len(objs)
        return gt_dict, npos

    # ---------- custom AP50 computation ----------
    def evaluate(self):
        """
        Gather predictions from all ranks and compute class-agnostic AP50.

        Returns:
            On the main process, ``OrderedDict(loc_bbox={"AP50": ap50})`` with
            AP50 in percent (0.0 when there are no detections at all).
            ``None`` on every other rank.
        """
        all_pred = comm.gather(self._predictions, dst=0)
        if not comm.is_main_process():
            return

        # Key 0 holds every line; a rank with no detections never created it.
        preds = [line for d in all_pred for line in d.get(0, [])]
        del all_pred

        # Parse predictions: img_id, score, box
        fields = [line.strip().split(" ") for line in preds]
        if not fields:  # no detections at all
            return OrderedDict(loc_bbox={"AP50": 0.0})

        img_ids = [s[0] for s in fields]
        scores = np.array([float(s[1]) for s in fields])
        bboxes = np.array([[float(z) for z in s[2:]] for s in fields])

        # Sort by confidence (descending)
        order = scores.argsort()[::-1]
        img_ids = [img_ids[i] for i in order]
        bboxes = bboxes[order]

        gt_dict, npos = self._load_gt()

        # ---------- match loop ----------
        tp = np.zeros(len(bboxes))
        fp = np.zeros(len(bboxes))

        for i, (img, bb) in enumerate(zip(img_ids, bboxes)):
            gt = gt_dict[img]["bbox"].astype(float)
            if gt.size == 0:
                fp[i] = 1
                continue

            # IoU with every GT box in that image (VOC "+1" pixel convention,
            # same as voc_eval)
            ixmin = np.maximum(gt[:, 0], bb[0])
            iymin = np.maximum(gt[:, 1], bb[1])
            ixmax = np.minimum(gt[:, 2], bb[2])
            iymax = np.minimum(gt[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
            ih = np.maximum(iymax - iymin + 1.0, 0.0)
            inters = iw * ih
            union = (
                (bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0)
                + (gt[:, 2] - gt[:, 0] + 1.0) * (gt[:, 3] - gt[:, 1] + 1.0)
                - inters
            )
            iou = inters / union
            jmax = np.argmax(iou)
            if iou[jmax] >= 0.5 and not gt_dict[img]["det"][jmax]:
                tp[i] = 1
                gt_dict[img]["det"][jmax] = True
            else:
                fp[i] = 1

        # ---------- precision-recall and AP ----------
        fp = np.cumsum(fp)
        tp = np.cumsum(tp)
        rec = tp / float(max(npos, 1))
        prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)

        # Use the shared helper with the same AP definition as the per-class
        # evaluation (11-point for VOC2007), so this AP50 is directly
        # comparable with the base evaluator's AP50.
        ap50 = voc_ap(rec, prec, use_07_metric=self._is_2007) * 100.0
        return OrderedDict(loc_bbox={"AP50": ap50})

##############################################################################
#
# Below code is modified from
# https://github.com/rbgirshick/py-faster-rcnn/blob/master/lib/datasets/voc_eval.py
# --------------------------------------------------------
# Fast/er R-CNN
# Licensed under The MIT License [see LICENSE for details]
# Written by Bharath Hariharan
# --------------------------------------------------------

"""Python implementation of the PASCAL VOC devkit's AP evaluation code."""


@lru_cache(maxsize=None)
def parse_rec(filename):
    """Read one PASCAL VOC annotation XML file.

    Returns a list with one dict per ``<object>`` element, holding the keys
    "name", "pose", "truncated" (int), "difficult" (int) and "bbox" (the four
    ``<bndbox>`` values ``[xmin, ymin, xmax, ymax]`` as ints, unmodified).

    Results are memoised per filename, so repeated calls return the very same
    list object; callers should treat it as read-only.
    """
    with PathManager.open(filename) as handle:
        tree = ET.parse(handle)

    def _text(node, tag):
        # Text of the first child named ``tag``.
        return node.find(tag).text

    parsed = []
    for node in tree.findall("object"):
        box = node.find("bndbox")
        parsed.append(
            {
                "name": _text(node, "name"),
                "pose": _text(node, "pose"),
                "truncated": int(_text(node, "truncated")),
                "difficult": int(_text(node, "difficult")),
                "bbox": [int(_text(box, key)) for key in ("xmin", "ymin", "xmax", "ymax")],
            }
        )
    return parsed


def voc_ap(rec, prec, use_07_metric=False):
    """Compute VOC average precision from a precision/recall curve.

    Args:
        rec: numpy array of recall values, one per ranked detection.
        prec: numpy array of precision values aligned with ``rec``.
        use_07_metric: if True, use the VOC2007 11-point interpolated AP
            (mean of the best precision at recall >= 0, 0.1, ..., 1.0);
            otherwise (default) the area under the monotone precision envelope.

    Returns:
        The average precision.
    """
    if not use_07_metric:
        # Pad with sentinels so the curve spans recall 0 -> 1.
        recall = np.concatenate(([0.0], rec, [1.0]))
        precision = np.concatenate(([0.0], prec, [0.0]))

        # Precision envelope: running maximum taken from the right.
        precision = np.maximum.accumulate(precision[::-1])[::-1]

        # Integrate (delta recall) * precision only where recall changes.
        step = recall[1:] != recall[:-1]
        return np.sum((recall[1:] - recall[:-1])[step] * precision[1:][step])

    # 11-point interpolated metric
    total = 0.0
    for threshold in np.arange(0.0, 1.1, 0.1):
        reached = rec >= threshold
        best = np.max(prec[reached]) if reached.any() else 0
        total = total + best / 11.0
    return total


def voc_eval(detpath, annopath, imagesetfile, classname, ovthresh=0.5, use_07_metric=False):
    """rec, prec, ap = voc_eval(detpath,
                                annopath,
                                imagesetfile,
                                classname,
                                [ovthresh],
                                [use_07_metric])

    Top level function that does the PASCAL VOC evaluation for one class.

    detpath: Path template to detections.
        detpath.format(classname) should produce the detection results file,
        one "image_id score xmin ymin xmax ymax" line per detection.
    annopath: Path template to annotations.
        annopath.format(imagename) should be the xml annotations file.
    imagesetfile: Text file containing the list of images, one image per line.
    classname: Category name to evaluate.
    [ovthresh]: Overlap threshold (default = 0.5). A detection can only be a
        true positive if its IoU with a GT box is strictly greater than this.
    [use_07_metric]: Whether to use VOC07's 11 point AP computation
        (default False)

    Returns:
        rec, prec: cumulative recall / precision arrays, one entry per
            detection in descending-confidence order.
        ap: average precision computed by ``voc_ap``.

    GT objects flagged "difficult" are not counted as positives, and a
    detection whose best match is a difficult object is neither a TP nor an FP.
    """
    # assumes detections are in detpath.format(classname)
    # assumes annotations are in annopath.format(imagename)
    # assumes imagesetfile is a text file with each line an image name

    # first load gt
    # read list of images
    with PathManager.open(imagesetfile, "r") as f:
        lines = f.readlines()
    imagenames = [x.strip() for x in lines]

    # load annots (parse_rec is memoised, so evaluating further classes or
    # thresholds reuses the already-parsed XML)
    recs = {}
    for imagename in imagenames:
        recs[imagename] = parse_rec(annopath.format(imagename))

    # extract gt objects for this class
    class_recs = {}
    npos = 0
    for imagename in imagenames:
        R = [obj for obj in recs[imagename] if obj["name"] == classname]
        bbox = np.array([x["bbox"] for x in R])
        # Builtin ``bool``: the ``np.bool`` alias was removed in NumPy 1.24.
        # To treat "difficult" objects as regular GT, use an all-False array here.
        difficult = np.array([x["difficult"] for x in R]).astype(bool)
        det = [False] * len(R)
        npos = npos + sum(~difficult)
        class_recs[imagename] = {"bbox": bbox, "difficult": difficult, "det": det}

    # read dets
    detfile = detpath.format(classname)
    with open(detfile, "r") as f:
        lines = f.readlines()

    splitlines = [x.strip().split(" ") for x in lines]
    image_ids = [x[0] for x in splitlines]
    confidence = np.array([float(x[1]) for x in splitlines])
    BB = np.array([[float(z) for z in x[2:]] for x in splitlines]).reshape(-1, 4)

    # sort by confidence
    sorted_ind = np.argsort(-confidence)
    BB = BB[sorted_ind, :]
    image_ids = [image_ids[x] for x in sorted_ind]

    # go down dets and mark TPs and FPs
    nd = len(image_ids)
    tp = np.zeros(nd)
    fp = np.zeros(nd)
    for d in range(nd):
        R = class_recs[image_ids[d]]
        bb = BB[d, :].astype(float)
        ovmax = -np.inf
        BBGT = R["bbox"].astype(float)

        if BBGT.size > 0:
            # compute overlaps (VOC "+1" inclusive-pixel convention)
            # intersection
            ixmin = np.maximum(BBGT[:, 0], bb[0])
            iymin = np.maximum(BBGT[:, 1], bb[1])
            ixmax = np.minimum(BBGT[:, 2], bb[2])
            iymax = np.minimum(BBGT[:, 3], bb[3])
            iw = np.maximum(ixmax - ixmin + 1.0, 0.0)
            ih = np.maximum(iymax - iymin + 1.0, 0.0)
            inters = iw * ih

            # union
            uni = (
                (bb[2] - bb[0] + 1.0) * (bb[3] - bb[1] + 1.0)
                + (BBGT[:, 2] - BBGT[:, 0] + 1.0) * (BBGT[:, 3] - BBGT[:, 1] + 1.0)
                - inters
            )

            overlaps = inters / uni
            ovmax = np.max(overlaps)
            jmax = np.argmax(overlaps)

        # jmax is only read when ovmax was updated, i.e. the image has GT boxes.
        if ovmax > ovthresh:
            if not R["difficult"][jmax]:
                if not R["det"][jmax]:
                    tp[d] = 1.0
                    R["det"][jmax] = 1
                else:
                    # duplicate detection of an already-matched GT box
                    fp[d] = 1.0
        else:
            fp[d] = 1.0

    # compute precision recall
    fp = np.cumsum(fp)
    tp = np.cumsum(tp)
    rec = tp / float(npos)
    # avoid divide by zero in case the first detection matches a difficult
    # ground truth
    prec = tp / np.maximum(tp + fp, np.finfo(np.float64).eps)
    ap = voc_ap(rec, prec, use_07_metric)

    return rec, prec, ap