import contextlib
import os
import io
import numpy as np
import torch
import json
import copy
import itertools
import logging
from collections import OrderedDict, defaultdict
from .omnilabeleval import OmniLabelEval
from .omnilabel import OmniLabel

from detectron2.structures import Boxes, BoxMode, pairwise_iou
from detectron2.data.datasets.coco import convert_to_coco_json
from detectron2.evaluation import DatasetEvaluator
from detectron2.utils.file_io import PathManager
import detectron2.utils.comm as comm
from detectron2.data import MetadataCatalog

class OmnilabelgEvaluator(DatasetEvaluator):
    """
    Evaluate detection predictions with the OmniLabel evaluation API.

    Per-image outputs collected in :meth:`process` are converted into per-box
    result dicts by ``instances_to_coco_json`` and scored by ``OmniLabelEval``
    against the dataset's ground-truth json file. The summary returned by
    ``OmniLabelEval.summarize`` is stored under the ``"bbox_odinw"`` key.
    """

    def __init__(
        self,
        dataset_name,
        tasks=None,
        distributed=True,
        output_dir=None,
        *,
        max_dets_per_image=None,
    ):
        """
        Args:
            dataset_name (str): name of a dataset registered in ``MetadataCatalog``.
            tasks: unused; accepted for signature compatibility with
                COCOEvaluator-style evaluators.
            distributed (bool): if True, predictions are gathered from all ranks
                and evaluated only on the main process.
            output_dir (str or None): directory where the raw predictions are
                dumped. Required when the dataset metadata has no ``json_file``,
                because the converted COCO-format json is cached there.
            max_dets_per_image: unused; accepted for signature compatibility.

        Raises:
            ValueError: if the metadata has no ``json_file`` and ``output_dir``
                is None.
        """
        self._logger = logging.getLogger(__name__)
        self.dataset_name = dataset_name

        self._distributed = distributed
        self._output_dir = output_dir

        self._cpu_device = torch.device("cpu")

        self._metadata = MetadataCatalog.get(dataset_name)
        if not hasattr(self._metadata, "json_file"):
            if output_dir is None:
                raise ValueError(
                    "output_dir must be provided to OmnilabelgEvaluator "
                    "for datasets not in COCO format."
                )
            self._logger.info(f"Trying to convert '{dataset_name}' to COCO format ...")

            cache_path = os.path.join(output_dir, f"{dataset_name}_coco_format.json")
            self._metadata.json_file = cache_path
            convert_to_coco_json(dataset_name, cache_path, allow_cached=True)

        self.json_file = PathManager.get_local_path(self._metadata.json_file)

        # Make the evaluator usable even if a caller forgets to call reset()
        # before process(); reset() is still called by the inference loop.
        self.reset()

    def reset(self):
        """Discard all buffered predictions."""
        self._predictions = []

    def process(self, inputs, outputs):
        """
        Convert one batch of model outputs and buffer them for evaluation.

        Args:
            inputs (list[dict]): dataset dicts; each must provide
                ``"image_id"`` and ``"description_ids"``.
            outputs (list[dict]): model outputs aligned with ``inputs``; each
                must provide ``"instances"``.
        """
        for inp, output in zip(inputs, outputs):
            instances = output["instances"].to(self._cpu_device)
            self._predictions.append(
                {
                    "image_id": inp["image_id"],
                    "instances": instances_to_coco_json(
                        instances, inp["image_id"], inp["description_ids"]
                    ),
                }
            )

    def evaluate(self, img_ids=None):
        """
        Gather buffered predictions (across ranks if distributed) and evaluate.

        Args:
            img_ids: forwarded to :meth:`_eval_predictions`; currently unused there.

        Returns:
            dict: a deep copy of the results. It is empty on non-main ranks,
            when no predictions were received, or when there were no
            detections to evaluate.
        """
        if self._distributed:
            comm.synchronize()
            predictions = comm.gather(self._predictions, dst=0)
            predictions = list(itertools.chain.from_iterable(predictions))

            if not comm.is_main_process():
                return {}
        else:
            predictions = self._predictions

        if not predictions:
            self._logger.warning("[OmnilabelgEvaluator] Did not receive valid predictions.")
            return {}

        if self._output_dir:
            PathManager.mkdirs(self._output_dir)
            file_path = os.path.join(self._output_dir, "instances_predictions.pth")
            with PathManager.open(file_path, "wb") as f:
                torch.save(predictions, f)

        self._results = OrderedDict()
        self._eval_predictions(predictions, img_ids=img_ids)
        # Copy so the caller can do whatever with results
        return copy.deepcopy(self._results)

    def _tasks_from_predictions(self):
        """Return the evaluation tasks supported by this evaluator."""
        return ("bbox",)

    def _eval_predictions(self, predictions, img_ids=None):
        """
        Run the OmniLabel evaluation and fill ``self._results``.

        Args:
            predictions (list[dict]): entries produced by :meth:`process`.
            img_ids: unused; accepted for interface compatibility.
        """
        self._logger.info("Preparing results for COCO format ...")
        ol_results = list(
            itertools.chain.from_iterable(x["instances"] for x in predictions)
        )
        self._logger.info(
            "Collected %d detections from %d images.", len(ol_results), len(predictions)
        )

        # An image set with zero detections is a legitimate (if poor) outcome;
        # warn and leave the results empty instead of asserting, which would
        # also be silently skipped under `python -O`.
        if not ol_results:
            self._logger.warning(
                "[OmnilabelgEvaluator] No detections to evaluate; "
                "skipping OmniLabel evaluation."
            )
            return

        self._logger.info("Evaluating predictions with Omnilabel API...")
        olgt = OmniLabel(path_json=self.json_file)
        oldt = olgt.load_res(result_json=ol_results)
        ol_eval = OmniLabelEval(gt=olgt, dt=oldt)
        ol_eval.evaluate()
        ol_eval.accumulate()
        # summarize() returns a pair: the first element is logged, the second
        # is what callers receive under the "bbox_odinw" key.
        summary_text, summary = ol_eval.summarize(verbose=True)
        self._logger.info(summary_text)
        self._results["bbox_odinw"] = summary
        

def instances_to_coco_json(instances, img_id, desc_id):
    """
    Turn one image's predicted ``Instances`` into a list of per-box result dicts.

    Boxes are converted from ``XYXY_ABS`` to ``XYWH_ABS`` (COCO-style
    ``[x, y, w, h]``).

    Args:
        instances (Instances): predictions for a single image. When non-empty,
            it must provide ``pred_boxes``, ``scores`` and ``pred_classes``.
        img_id (int): the image id, copied into every result.
        desc_id: the image's description ids, copied unchanged into every result.

    Returns:
        list[dict]: one dict per instance with keys ``"image_id"``,
        ``"description_ids"``, ``"category_id"``, ``"bbox"`` and ``"scores"``
        (a one-element list). Returns an empty list when ``instances`` is empty.
    """
    if len(instances) == 0:
        return []

    xywh_boxes = BoxMode.convert(
        instances.pred_boxes.tensor.numpy(), BoxMode.XYXY_ABS, BoxMode.XYWH_ABS
    ).tolist()
    score_list = instances.scores.tolist()
    label_list = instances.pred_classes.tolist()

    # NOTE(review): every result carries the image's full `desc_id` but only a
    # single score, so len(description_ids) may differ from len(scores). The
    # public OmniLabel result format pairs them one-to-one; confirm that the
    # local `.omnilabel` `load_res` handles this via "category_id".
    return [
        {
            "image_id": img_id,
            "description_ids": desc_id,
            "category_id": label,
            "bbox": box,
            "scores": [score],
        }
        for box, score, label in zip(xywh_boxes, score_list, label_list)
    ]