from abc import ABCMeta, abstractmethod
from typing import Any, Dict

import torch
import torch.nn as nn


class BaseEvaluation(metaclass=ABCMeta):
    """Abstract interface for a step-based evaluation.

    Subclasses declare how many steps the evaluation has (``num_steps``) and
    implement the evaluation, summarization and visualization hooks. This
    base class only owns a step counter and guarantees it never exceeds
    ``num_steps``.
    """

    def __init__(self) -> None:
        # Step counter advanced by ``increment_step``; starts at 0.
        self._current_step = 0

    def increment_step(self) -> None:
        """Advance the step counter by one.

        Raises:
            ValueError: If advancing would exceed ``num_steps``. In that case
                the counter is left unchanged.
        """
        next_step = self._current_step + 1
        # Validate before committing so a rejected call does not leave the
        # counter in an out-of-range state.
        if next_step > self.num_steps:
            raise ValueError(
                f'Current step: {next_step} has exceeded the '
                f'maximal number of steps: {self.num_steps}')
        self._current_step = next_step

    @property
    def current_step(self) -> int:
        """The current value of the step counter (read-only)."""
        return self._current_step

    @property
    @abstractmethod
    def num_steps(self) -> int:
        """Maximum value the step counter may reach (defined by subclasses)."""

    @abstractmethod
    def reset_cache(self) -> None:
        """Reset subclass-held cached state.

        NOTE(review): semantics are defined entirely by subclasses. The base
        class does not reset ``_current_step`` here.
        """

    @abstractmethod
    def evaluate(
            self,
            model: nn.Module,
            img: torch.Tensor,
            label: torch.Tensor,
            attr_map: torch.Tensor,
            gt_mask: torch.Tensor,
            **kwargs: Any) -> None:
        """Run the evaluation on one input.

        Args:
            model: The model under evaluation.
            img: Input tensor. Presumably an image batch; confirm the shape
                against subclasses.
            label: Target label tensor for ``img``.
            attr_map: Presumably an attribution map computed for ``img``.
                TODO confirm.
            gt_mask: Presumably a ground-truth mask aligned with ``img``.
                TODO confirm.
            **kwargs: Extra subclass-specific options.
        """

    @abstractmethod
    def summarize_step(self) -> None:
        """Summarize results for the current step (subclass-defined)."""

    @abstractmethod
    def summarize_total(self) -> Dict:
        """Aggregate and return results over all steps as a dict."""

    @abstractmethod
    def visualize_result(self, total_result: Dict, save_path: str) -> None:
        """Visualize ``total_result`` and save the output to ``save_path``.

        ``total_result`` is presumably the dict returned by
        ``summarize_total``. Verify this against callers.
        """
