from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.structures.bounding_box import BoxList
import json
import numpy as np
import os.path as osp
import os
from prettytable import PrettyTable

import xml.etree.ElementTree as ET
from collections import defaultdict
from pathlib import Path
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import maskrcnn_benchmark.utils.mdetr_dist as dist
#### The following loading utilities are imported from
#### https://github.com/BryanPlummer/flickr30k_entities/blob/68b3d6f12d1d710f96233f6bd2b6de799d6f4e5b/flickr30k_entities_utils.py
# Changelog:
#    - Added typing information
#    - Completed docstrings

def get_sentence_data(filename) -> List[Dict[str, Any]]:
    """
    Read a Flickr30K Entities sentence file and extract its phrase markup.

    Each non-empty line is one sentence. An annotated phrase starts with a
    header token of the form ``[/EN#<id>/<type>[/<type>...]`` and runs until
    the first following token that ends with ``]``.

    input:
      filename - full file path to the sentence file to parse

    output:
      a list with one dictionary per sentence, holding:
          sentence - the sentence text with the phrase markup removed
          phrases - a list with one dictionary per annotated phrase:
                      first_word_index - position of the phrase's first word
                                         in the cleaned sentence
                      phrase - the text of the annotated phrase
                      phrase_id - the identifier that follows "EN#"
                      phrase_type - list of the coarse categories given in
                                    the header token

    A phrase that is opened but never closed contributes its words to the
    sentence but is not reported in ``phrases``.
    """
    with open(filename, "r") as handle:
        raw_lines = handle.read().split("\n")

    parsed: List[Dict[str, Any]] = []
    for raw_line in raw_lines:
        if not raw_line:
            continue

        plain_words: List[str] = []
        phrase_entries: List[Dict[str, Any]] = []
        # (first_word_index, phrase_id, phrase_type) of the phrase currently
        # being read, or None when we are outside any phrase.
        pending: Optional[Tuple[int, str, List[str]]] = None
        pending_tokens: List[str] = []

        for token in raw_line.split():
            if pending is None:
                if token.startswith("["):
                    # Header token: carries metadata only, it is not a word.
                    fields = token.split("/")
                    pending = (len(plain_words), fields[1][3:], fields[2:])
                else:
                    plain_words.append(token)
                continue

            closes_phrase = token.endswith("]")
            if closes_phrase:
                token = token[:-1]
            pending_tokens.append(token)
            plain_words.append(token)

            if closes_phrase:
                start, ident, categories = pending
                phrase_entries.append(
                    {
                        "first_word_index": start,
                        "phrase": " ".join(pending_tokens),
                        "phrase_id": ident,
                        "phrase_type": categories,
                    }
                )
                pending = None
                pending_tokens = []

        parsed.append({"sentence": " ".join(plain_words), "phrases": phrase_entries})

    return parsed


def get_annotations(filename) -> Dict[str, Union[int, List[str], Dict[str, List[List[int]]]]]:
    """
    Load one Flickr30K Entities xml annotation file.

    input:
      filename - full file path to the annotations file to parse

    output:
      dictionary with the following fields:
          one integer entry per child of the <size> element, keyed by the
          child's tag (documented upstream as height, width and depth)
          boxes - a dictionary mapping each identifier to its list of
                  boxes in the [xmin ymin xmax ymax] format
          nobox - list of identifiers whose object has no <bndbox> and a
                  positive <nobndbox> value
          scene - list of identifiers whose object has no <bndbox> and a
                  positive <scene> value
    """

    def _int_child(node, tag: str) -> int:
        # Integer value of the first child of `node` named `tag`.
        return int(node.findall(tag)[0].text)

    root = ET.parse(filename).getroot()

    anno_info: Dict[str, Union[int, List[str], Dict[str, List[List[int]]]]] = {}
    for dimension in root.findall("size")[0]:
        assert dimension.text
        anno_info[dimension.tag] = int(dimension.text)

    boxes_by_id: Dict[str, List[List[int]]] = {}
    nobox_ids: List[str] = []
    scene_ids: List[str] = []

    for obj in root.findall("object"):
        bndboxes = obj.findall("bndbox")
        # An object can carry several <name> tags; each of them receives the
        # object's (first) box or its nobox/scene flags.
        for name_node in obj.findall("name"):
            entity_id = name_node.text
            assert entity_id

            if len(bndboxes) > 0:
                corners = [_int_child(bndboxes[0], tag) for tag in ("xmin", "ymin", "xmax", "ymax")]
                boxes_by_id.setdefault(entity_id, []).append(corners)
                continue

            if _int_child(obj, "nobndbox") > 0:
                nobox_ids.append(entity_id)
            if _int_child(obj, "scene") > 0:
                scene_ids.append(entity_id)

    anno_info["boxes"] = boxes_by_id
    anno_info["nobox"] = nobox_ids
    anno_info["scene"] = scene_ids

    return anno_info


#### END of import from flickr30k_entities


#### Bounding box utilities imported from torchvision and converted to numpy
def box_area(boxes: np.array) -> np.array:
    """
    Area of each box in a 2-D array of (x1, y1, x2, y2) boxes.

    Args:
        boxes (ndarray[N, 4]): boxes in (x1, y1, x2, y2) format. The area is
            computed as (x2 - x1) * (y2 - y1), so it is only meaningful when
            ``x1 <= x2`` and ``y1 <= y2``.

    Returns:
        area (ndarray[N]): area for each box
    """
    assert boxes.ndim == 2 and boxes.shape[-1] == 4
    widths = boxes[:, 2] - boxes[:, 0]
    heights = boxes[:, 3] - boxes[:, 1]
    return widths * heights


# implementation from https://github.com/kuangliu/torchcv/blob/master/torchcv/utils/box.py
# with slight modifications
def _box_inter_union(boxes1: np.array, boxes2: np.array) -> Tuple[np.array, np.array]:
    """
    Pairwise intersection and union areas between two sets of boxes.

    Args:
        boxes1 (ndarray[N, 4]): boxes in (x1, y1, x2, y2) format
        boxes2 (ndarray[M, 4]): boxes in (x1, y1, x2, y2) format

    Returns:
        (inter, union): two [N, M] arrays holding, for every pair, the area
        of the overlap and the area of the union.
    """
    areas_a = box_area(boxes1)
    areas_b = box_area(boxes2)

    # Broadcasting boxes1 over a new middle axis yields every (i, j) pair.
    top_left = np.maximum(boxes1[:, None, :2], boxes2[:, :2])  # [N,M,2]
    bottom_right = np.minimum(boxes1[:, None, 2:], boxes2[:, 2:])  # [N,M,2]

    # Disjoint boxes give a negative extent; clamp it so their overlap is 0.
    extent = np.clip(bottom_right - top_left, 0, None)  # [N,M,2]
    inter = extent[:, :, 0] * extent[:, :, 1]  # [N,M]

    union = areas_a[:, None] + areas_b - inter

    return inter, union


def box_iou(boxes1: np.array, boxes2: np.array) -> np.array:
    """
    Pairwise intersection-over-union (Jaccard index) between two box sets.

    Both sets of boxes are expected to be in ``(x1, y1, x2, y2)`` format.
    A pair whose union area is zero (both boxes degenerate) produces a
    non-finite value, since no guard is applied to the division.

    Args:
        boxes1 (ndarray[N, 4])
        boxes2 (ndarray[M, 4])

    Returns:
        iou (ndarray[N, M]): the NxM matrix containing the pairwise IoU values for every element in boxes1 and boxes2
    """
    intersection, union = _box_inter_union(boxes1, boxes2)
    return np.true_divide(intersection, union)


#### End of import of box utilities

def _merge_boxes(boxes: List[List[int]]) -> List[List[int]]:
    """
    Return the boxes corresponding to the smallest enclosing box containing all the provided boxes
    The boxes are expected in [x1, y1, x2, y2] format
    """
    if len(boxes) == 1:
        return boxes

    np_boxes = np.asarray(boxes)

    return [[np_boxes[:, 0].min(), np_boxes[:, 1].min(), np_boxes[:, 2].max(), np_boxes[:, 3].max()]]


class RecallTracker:
    """Accumulates hit/miss counts so that recall@k can be reported per category."""

    def __init__(self, topk: Sequence[int]):
        """
        Parameters:
           - topk : the values of k to track (eg, recall@1, recall@10, ...).
                    Only these values are accepted later by add_positive and
                    add_negative.
        """
        # For each k: category -> number of logged samples / number of hits.
        self.total_byk_bycat: Dict[int, Dict[str, int]] = {}
        self.positives_byk_bycat: Dict[int, Dict[str, int]] = {}
        for k in topk:
            self.total_byk_bycat[k] = defaultdict(int)
            self.positives_byk_bycat[k] = defaultdict(int)

    def _record(self, k: int, category: str, hit: bool) -> None:
        """Count one sample @k for `category`, and one positive if `hit` is true."""
        if k not in self.total_byk_bycat:
            raise RuntimeError(f"{k} is not a valid recall threshold")
        self.total_byk_bycat[k][category] += 1
        if hit:
            self.positives_byk_bycat[k][category] += 1

    def add_positive(self, k: int, category: str):
        """Log a positive hit @k for given category.

        Raises RuntimeError if k is not one of the tracked values.
        """
        self._record(k, category, True)

    def add_negative(self, k: int, category: str):
        """Log a negative hit @k for given category.

        Raises RuntimeError if k is not one of the tracked values.
        """
        self._record(k, category, False)

    def report(self) -> Dict[int, Dict[str, float]]:
        """Return a condensed report of the results as a dict of dict.
        report[k][cat] is the recall@k for the given category, i.e. the
        number of positives divided by the number of logged samples.
        Categories that never received a sample @k are absent from report[k].
        """
        summary: Dict[int, Dict[str, float]] = {}
        for k, totals in self.total_byk_bycat.items():
            assert k in self.positives_byk_bycat
            hits = self.positives_byk_bycat[k]
            summary[k] = {cat: hits[cat] / count for cat, count in totals.items()}
        return summary


class Flickr30kEntitiesRecallEvaluator:
    """Computes phrase-grounding recall@k on one split of Flickr30K Entities.

    On construction, the ground-truth boxes and sentences of every image of
    the split are loaded from disk. ``evaluate`` then scores a list of
    predictions against them.
    """

    def __init__(
            self,
            flickr_path: str,
            subset: str = "test",
            topk: Sequence[int] = (1, 5, 10, -1),
            iou_thresh: float = 0.5,
            merge_boxes: bool = False,
            verbose: bool = True,
    ):
        """
        Parameters:
           - flickr_path : dataset root. It must contain "<subset>.txt" (one
                           image id per line) and the "Annotations" and
                           "Sentences" folders.
           - subset : one of "train", "test" or "val"
           - topk : values of k for which recall@k is computed; -1 means that
                    all the predicted boxes are considered
           - iou_thresh : minimal IoU with a ground-truth box to count a hit
           - merge_boxes : if True, the ground-truth boxes of each phrase are
                           replaced by their smallest enclosing box
           - verbose : if True, print progress information
        """
        assert subset in ["train", "test", "val"], f"Wrong flickr subset {subset}"

        self.topk = topk
        self.iou_thresh = iou_thresh

        flickr_path = Path(flickr_path)

        # We load the image ids corresponding to the current subset.
        # Blank lines are skipped: they would otherwise yield an empty id and
        # a failed attempt at opening ".xml".
        with open(flickr_path / f"{subset}.txt") as file_d:
            self.img_ids = [line.strip() for line in file_d if line.strip()]

        if verbose:
            print(f"Flickr subset contains {len(self.img_ids)} images")

        # Read the box annotations for all the images:
        # image id -> phrase id -> list of [xmin, ymin, xmax, ymax] boxes
        self.imgid2boxes: Dict[str, Dict[str, List[List[int]]]] = {}

        if verbose:
            print("Loading annotations...")

        for img_id in self.img_ids:
            anno_info = get_annotations(flickr_path / "Annotations" / f"{img_id}.xml")["boxes"]
            if merge_boxes:
                anno_info = {phrase_id: _merge_boxes(boxes) for phrase_id, boxes in anno_info.items()}
            self.imgid2boxes[img_id] = anno_info

        # Read the sentences annotations:
        # image id -> one entry per sentence, which is either the list of its
        # phrases that have boxes, or None when it has no such phrase.
        self.imgid2sentences: Dict[str, List[List[Optional[Dict]]]] = {}

        if verbose:
            print("Loading sentences...")

        # Ids ("<image id>_<sentence index>") of the sentences to evaluate
        self.all_ids: List[str] = []
        tot_phrases = 0
        for img_id in self.img_ids:
            sentence_info = get_sentence_data(flickr_path / "Sentences" / f"{img_id}.txt")
            per_sentence: List[Optional[List[Dict]]] = [None] * len(sentence_info)

            # Some phrases don't have boxes, we filter them.
            known_phrase_ids = self.imgid2boxes[img_id]
            for sent_id, sentence in enumerate(sentence_info):
                phrases = [phrase for phrase in sentence["phrases"] if phrase["phrase_id"] in known_phrase_ids]
                if len(phrases) > 0:
                    per_sentence[sent_id] = phrases
                    self.all_ids.append(f"{img_id}_{sent_id}")
                tot_phrases += len(phrases)

            self.imgid2sentences[img_id] = per_sentence

        if verbose:
            print(f"There are {tot_phrases} phrases in {len(self.all_ids)} sentences to evaluate")

    def evaluate(self, predictions: List[Dict]):
        """Score the predictions and return the recall report.

        Parameters:
           - predictions : list of dicts, one per sentence, with the keys
               image_id - id of the image
               sentence_id - index of the sentence within the image
               boxes - one entry per phrase of the sentence that has boxes;
                       each entry is the list of [x1, y1, x2, y2] boxes
                       predicted for that phrase. recall@k only looks at the
                       first k boxes, so they are presumably expected to be
                       sorted by decreasing confidence (to be confirmed
                       against the caller).

        Returns:
            report[k][cat], the recall@k for each phrase category, plus the
            "all" category which aggregates every phrase.

        Raises:
            RuntimeError if a sentence has the wrong number of phrase
            predictions, or if some of the expected sentences received no
            prediction at all.

        Duplicated predictions and predictions for sentences that are not
        expected are reported with a warning and ignored. A phrase for which
        no box is predicted counts as a miss for every k.
        """
        evaluated_ids = set()
        # Set copy of the expected ids: membership is tested once per prediction.
        expected_ids = set(self.all_ids)

        recall_tracker = RecallTracker(self.topk)

        for pred in predictions:
            cur_id = f"{pred['image_id']}_{pred['sentence_id']}"
            if cur_id in evaluated_ids:
                print(
                    "Warning, multiple predictions found for sentence "
                    f"{pred['sentence_id']} in image {pred['image_id']}"
                )
                continue

            # Skip the sentences with no valid phrase
            if cur_id not in expected_ids:
                if len(pred["boxes"]) != 0:
                    print(
                        f"Warning, in image {pred['image_id']} we were not expecting predictions "
                        f"for sentence {pred['sentence_id']}. Ignoring them."
                    )
                continue

            evaluated_ids.add(cur_id)

            image_id = str(pred["image_id"])
            sentence_id = int(pred["sentence_id"])
            pred_boxes = pred["boxes"]
            if image_id not in self.imgid2sentences:
                raise RuntimeError(f"Unknown image id {pred['image_id']}")
            if not 0 <= sentence_id < len(self.imgid2sentences[image_id]):
                raise RuntimeError(f"Unknown sentence id {pred['sentence_id']}" f" in image {pred['image_id']}")

            phrases = self.imgid2sentences[image_id][sentence_id]
            if len(pred_boxes) != len(phrases):
                raise RuntimeError(
                    f"Error, got {len(pred_boxes)} predictions, expected {len(phrases)} "
                    f"for sentence {pred['sentence_id']} in image {pred['image_id']}"
                )

            for cur_boxes, phrase in zip(pred_boxes, phrases):
                target_boxes = self.imgid2boxes[image_id][phrase["phrase_id"]]

                candidates = np.asarray(cur_boxes)
                if candidates.size == 0:
                    # No box predicted for this phrase: there is nothing to
                    # match, so use an empty IoU matrix instead of failing on
                    # the shape check of the IoU computation.
                    ious = np.zeros((0, len(target_boxes)))
                else:
                    ious = box_iou(candidates, np.asarray(target_boxes))

                for k in self.topk:
                    if k == -1:
                        considered = ious
                    else:
                        assert k > 0
                        considered = ious[:k]

                    hit = considered.size > 0 and considered.max() >= self.iou_thresh
                    record = recall_tracker.add_positive if hit else recall_tracker.add_negative
                    record(k, "all")
                    for phrase_type in phrase["phrase_type"]:
                        record(k, phrase_type)

        if len(evaluated_ids) != len(self.all_ids):
            print("ERROR, the number of evaluated sentence doesn't match. Missing predictions:")
            un_processed = expected_ids - evaluated_ids
            for missing in sorted(un_processed):
                # Split on the last underscore only, the sentence index is
                # always the final component of the id.
                img_id, sent_id = missing.rsplit("_", 1)
                print(f"\t sentence {sent_id} in image {img_id}")
            raise RuntimeError("Missing predictions")

        return recall_tracker.report()


class FlickrEvaluator(object):
    """Collects predictions (possibly across processes) and reports Flickr30K Entities recall."""

    def __init__(
            self,
            flickr_path,
            subset,
            top_k=(1, 5, 10, -1),
            iou_thresh=0.5,
            merge_boxes=False,
    ):
        """Build the underlying recall evaluator (non verbose) and start with no prediction.

        top_k must be a list or a tuple; the other arguments are forwarded
        unchanged to Flickr30kEntitiesRecallEvaluator.
        """
        assert isinstance(top_k, (list, tuple))

        self.evaluator = Flickr30kEntitiesRecallEvaluator(
            flickr_path,
            subset=subset,
            topk=top_k,
            iou_thresh=iou_thresh,
            merge_boxes=merge_boxes,
            verbose=False,
        )
        # Predictions received so far through update()
        self.predictions = []
        # Report of the last evaluation, set by summarize() on the main process
        self.results = None

    def accumulate(self):
        """Nothing to accumulate for this evaluator; does nothing."""
        return None

    def update(self, predictions):
        """Append a batch of predictions to the ones already stored."""
        self.predictions.extend(predictions)

    def synchronize_between_processes(self):
        """Replace the local predictions by the concatenation of the predictions gathered by dist.all_gather."""
        gathered = dist.all_gather(self.predictions)
        self.predictions = [prediction for chunk in gathered for prediction in chunk]

    def summarize(self):
        """Evaluate the stored predictions, print the recall table and return the scores.

        On the main process, returns a flat dict mapping "<header>_<category>"
        to the recall value, where <header> is "Recall@<k>", or "Upper_bound"
        for k == -1. On every other process nothing is evaluated and the
        tuple (None, None) is returned, so callers should only rely on the
        value obtained on the main process.
        """
        if not dist.is_main_process():
            return None, None

        self.results = self.evaluator.evaluate(self.predictions)

        # The columns are the categories reported for the first k
        categories = sorted(list(self.results.values())[0].keys())
        table = PrettyTable()
        table.field_names = ["Recall@k"] + categories

        score = {}
        for k, recalls in self.results.items():
            label = "Upper_bound" if k == -1 else f"Recall@{k}"
            row = [label]
            for cat in categories:
                value = recalls[cat]
                score[f"{label}_{cat}"] = value
                row.append(value)
            table.add_row(row)

        print(table)

        return score
