from copy import deepcopy
from typing import Any, Dict, List

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from .base import BaseEvaluation


class AttributionPercentage(BaseEvaluation):
    """Fraction of total attribution mass that falls inside a ground-truth mask.

    For each evaluated image the score is::

        sum(attr_map * (gt_mask == pos_val)) / (sum(attr_map) + 1e-8)

    Per-image scores are appended to a cache by :meth:`evaluate`, copied into
    the results by :meth:`summarize_step`, and reported by
    :meth:`summarize_total`.

    Args:
        pos_val: Value in ``gt_mask`` that marks a positive pixel. The default
            of 255 presumably corresponds to 8-bit binary masks — confirm
            against the dataset that supplies the masks.
    """

    def __init__(self, pos_val: int = 255) -> None:
        super().__init__()
        self.pos_val = pos_val
        # Per-image percentages appended by ``evaluate`` since the last reset.
        self._cache: List[float] = []
        # Snapshot of ``_cache`` taken by ``summarize_step``.
        self._results: List[float] = []

    @property
    def num_steps(self) -> int:
        """Number of evaluation steps; this metric uses a single step."""
        return 1

    def reset_cache(self) -> None:
        """Discard all per-image percentages collected so far."""
        self._cache.clear()

    def evaluate(
            self,
            model: nn.Module,
            img: torch.Tensor,
            label: torch.Tensor,
            attr_map: torch.Tensor,
            gt_mask: torch.Tensor,
            **kwargs: Any) -> None:
        """Compute the attribution percentage for one image and cache it.

        Args:
            model: Not used by this metric (presumably accepted to satisfy
                the ``BaseEvaluation`` interface — confirm).
            img: A single image without batch dimension. It must be 3-D with
                1 or 3 entries along dim 0. Only its shape is checked.
            label: Not used by this metric.
            attr_map: 2-D attribution map with the same shape as ``gt_mask``.
            gt_mask: 2-D ground-truth mask. Pixels equal to ``self.pos_val``
                count as positive.
            **kwargs: Ignored.

        Raises:
            AssertionError: If the shapes do not satisfy the checks above.
        """
        # single image without batch dimension.
        assert img.dim() == 3 and img.shape[0] in (1, 3), (
            f'expected a 3-D image with 1 or 3 channels, got shape '
            f'{tuple(img.shape)}')
        assert attr_map.shape == gt_mask.shape and gt_mask.dim() == 2, (
            f'attr_map and gt_mask must be 2-D with equal shapes, got '
            f'{tuple(attr_map.shape)} and {tuple(gt_mask.shape)}')
        positive = (gt_mask == self.pos_val)
        # sum of true positive attribution values
        sum_tp_attr = (attr_map * positive).sum()
        # Small epsilon avoids division by zero for an all-zero attribution map.
        total_attr = attr_map.sum() + 1e-8
        percentage = (sum_tp_attr / total_attr).item()

        self._cache.append(percentage)

    def summarize_step(self) -> None:
        """Snapshot the cached per-image percentages into the results."""
        self._results = deepcopy(self._cache)

    def summarize_total(self) -> Dict:
        """Return the per-image percentages under the key ``'percentages'``.

        A copy is returned so callers cannot mutate this metric's state.
        """
        total_result = {'percentages': list(self._results)}
        return total_result

    def visualize_result(self, total_result: Dict, save_path: str) -> None:
        """Save a violin plot of the per-image percentages to ``save_path``.

        The single x-axis tick is labelled with the mean percentage.

        Args:
            total_result: Mapping with a ``'percentages'`` sequence, e.g. the
                dict returned by :meth:`summarize_total`.
            save_path: Output path forwarded to ``Figure.savefig``.
        """
        # Read the input before creating the figure, so a missing key cannot
        # leave an orphaned figure behind.
        percentages = total_result['percentages']
        mean_percentage = np.mean(percentages)

        fig, ax = plt.subplots(1, 1, figsize=(8, 6))
        try:
            ax.set_title('Attribution Percentage')
            ax.set_xticks([1])
            ax.set_xticklabels([f'Mean Attr%: {mean_percentage:.4f}'])
            ax.violinplot(percentages)
            # Operate on this figure explicitly rather than on pyplot's
            # "current" figure, which may have changed.
            fig.tight_layout()
            fig.savefig(save_path, bbox_inches='tight')
        finally:
            # pyplot keeps every figure alive until it is closed; without this,
            # repeated calls leak memory.
            plt.close(fig)
