# Copyright (c) Facebook, Inc. and its affiliates.

# pyre-unsafe

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Optional

from detectron2.structures import Instances

ModelOutput = Dict[str, Any]
SampledData = Dict[str, Any]


@dataclass
class _Sampler:
    """
    Sampler registry entry that contains:
     - src (str): source field to sample from (deleted after sampling)
     - dst (Optional[str]): destination field to sample to, if not None
     - func (Optional[Callable: Any -> Any]): function that performs sampling,
         if None, reference copy is performed
    """

    src: str
    dst: Optional[str]
    func: Optional[Callable[[Any], Any]]


class PredictionToGroundTruthSampler:
    """
    Sampler implementation that converts predictions to GT using registered
    samplers for different fields of `Instances`.
    """

    def __init__(self, dataset_name: str = ""):
        self.dataset_name = dataset_name
        self._samplers = {}
        self.register_sampler("pred_boxes", "gt_boxes", None)
        self.register_sampler("pred_classes", "gt_classes", None)
        # delete scores
        self.register_sampler("scores")

    def __call__(self, model_output: List[ModelOutput]) -> List[SampledData]:
        """
        Transform model output into ground truth data through sampling

        Args:
          model_output (Dict[str, Any]): model output
        Returns:
          Dict[str, Any]: sampled data
        """
        for model_output_i in model_output:
            instances: Instances = model_output_i["instances"]
            # transform data in each field
            for _, sampler in self._samplers.items():
                if not instances.has(sampler.src) or sampler.dst is None:
                    continue
                if sampler.func is None:
                    instances.set(sampler.dst, instances.get(sampler.src))
                else:
                    instances.set(sampler.dst, sampler.func(instances))
            # delete model output data that was transformed
            for _, sampler in self._samplers.items():
                if sampler.src != sampler.dst and instances.has(sampler.src):
                    instances.remove(sampler.src)
            model_output_i["dataset"] = self.dataset_name
        return model_output

    def register_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
        func: Optional[Callable[[Any], Any]] = None,
    ):
        """
        Register sampler for a field

        Args:
          prediction_attr (str): field to replace with a sampled value
          gt_attr (Optional[str]): field to store the sampled value to, if not None
          func (Optional[Callable: Any -> Any]): sampler function
        """
        self._samplers[(prediction_attr, gt_attr)] = _Sampler(
            src=prediction_attr, dst=gt_attr, func=func
        )

    def remove_sampler(
        self,
        prediction_attr: str,
        gt_attr: Optional[str] = None,
    ):
        """
        Remove sampler for a field

        Args:
          prediction_attr (str): field to replace with a sampled value
          gt_attr (Optional[str]): field to store the sampled value to, if not None
        """
        assert (prediction_attr, gt_attr) in self._samplers
        del self._samplers[(prediction_attr, gt_attr)]
