from copy import deepcopy
from typing import Any, Dict, List, Optional, Tuple

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn

from .base import BaseEvaluation


class GroundTruthEvaluation(BaseEvaluation):
    """Score attribution maps against a ground-truth mask.

    For every evaluated image the attribution map is compared with the pixels of
    ``gt_mask`` equal to ``pos_val`` (and, optionally, ``neg_val``). The result is a
    tuple of attribution-weighted precision/recall values, cached per image and
    aggregated by :meth:`summarize_step` / :meth:`summarize_total`.

    Args:
        pos_val: Mask value marking "positive" ground-truth pixels.
        neg_val: Mask value marking "negative" ground-truth pixels, or ``None`` to
            compute only the overall and positive metrics.
        normalize: If True, each attribution map is rescaled before scoring: its
            non-negative part is divided by the map's maximum and its negative part
            by the absolute value of the map's minimum.
    """

    # Metric names, in the order they appear in every cached result tuple.
    _BASE_KEYS = ('overall_precision', 'overall_recall', 'positive_precision', 'positive_recall')
    _NEG_KEYS = ('negative_precision', 'negative_recall')

    def __init__(self, pos_val: int, neg_val: Optional[int], normalize: bool = True) -> None:
        super().__init__()
        self.pos_val = pos_val
        self.neg_val = neg_val
        self.normalize = normalize

        # Per-image result tuples accumulated by `evaluate`.
        self._cache: List[Tuple[float, ...]] = []
        # Snapshot of the (padded) cache taken by `summarize_step`.
        self._results: List[Tuple[float, ...]] = []

    @property
    def num_steps(self) -> int:
        """Number of evaluation steps; this evaluation has a single step."""
        return 1

    def reset_cache(self) -> None:
        """Discard all per-image results accumulated so far."""
        self._cache.clear()

    def evaluate(
            self,
            model: nn.Module,
            img: torch.Tensor,
            label: torch.Tensor,
            attr_map: torch.Tensor,
            gt_mask: torch.Tensor,
            **kwargs: Any) -> None:
        """Compute the metrics for one image and append them to the cache.

        Args:
            model: Not used by this evaluation.
            img: A single image without batch dimension: a 3-D tensor whose first
                dimension is 1 or 3. Only its shape is checked.
            label: Not used by this evaluation.
            attr_map: 2-D attribution map with the same shape as ``gt_mask``.
                It is not modified.
            gt_mask: 2-D ground-truth mask compared against ``pos_val`` / ``neg_val``.
            **kwargs: Ignored.

        The cached tuple is ``(overall_prec, overall_recall, pos_prec, pos_recall)``.
        It is extended with ``(neg_prec, neg_recall)`` only when ``neg_val`` is set
        and the attribution map contains negative values.
        """
        # single image without batch dimension.
        assert img.dim() == 3 and img.shape[0] in (1, 3)
        assert attr_map.shape == gt_mask.shape and gt_mask.dim() == 2
        if self.normalize:
            attr_map = self._normalize_attr_map(attr_map)

        # "positive" metrics: only positive attributions against the positive region.
        pos_gt_mask = gt_mask == self.pos_val
        pos_prec, pos_recall = self.compute_precision_and_recall(
            torch.clamp(attr_map, min=0.0), pos_gt_mask)

        # "overall" metrics: attribution magnitude against every labelled region.
        if self.neg_val is not None:
            neg_gt_mask: Optional[torch.Tensor] = gt_mask == self.neg_val
            overall_gt_mask = torch.logical_or(pos_gt_mask, neg_gt_mask)
        else:
            neg_gt_mask = None
            overall_gt_mask = pos_gt_mask
        overall_prec, overall_recall = self.compute_precision_and_recall(
            torch.abs(attr_map), overall_gt_mask)

        single_result: Tuple[float, ...] = (overall_prec, overall_recall, pos_prec, pos_recall)

        # "negative" metrics exist only when a negative region is defined and
        # the map actually contains negative attributions.
        if neg_gt_mask is not None and torch.min(attr_map) < 0:
            neg_attr_map = torch.abs(torch.clamp(attr_map, max=0.0))
            single_result += self.compute_precision_and_recall(neg_attr_map, neg_gt_mask)

        self._cache.append(single_result)

    @staticmethod
    def _normalize_attr_map(attr_map: torch.Tensor) -> torch.Tensor:
        """Return a rescaled copy of ``attr_map``; the input tensor is left untouched.

        Non-negative values are divided by the map's maximum and negative values by
        the absolute value of the map's minimum, each guarded with ``1e-8``.
        """
        # Work on a detached copy so the caller's tensor is never modified in place.
        # Integer maps are promoted to float so the in-place division is valid.
        if attr_map.is_floating_point():
            normalized = attr_map.detach().clone()
        else:
            normalized = attr_map.float()
        pos_max = torch.max(normalized)
        neg_abs_max = torch.min(normalized).abs()
        # Build both masks before dividing so each element is scaled exactly once.
        non_neg_region = normalized >= 0
        neg_region = normalized < 0
        normalized[non_neg_region] /= (pos_max + 1e-8)
        normalized[neg_region] /= (neg_abs_max + 1e-8)
        return normalized

    @staticmethod
    def compute_precision_and_recall(
            attr_map: torch.Tensor, binary_gt_mask: torch.Tensor) -> Tuple[float, float]:
        """Attribution-weighted precision and recall of ``attr_map`` w.r.t. a binary mask.

        Args:
            attr_map: Attribution values (the callers pass non-negative maps).
            binary_gt_mask: Boolean/0-1 mask of the ground-truth region.

        Returns:
            ``(precision, recall)``. Precision is the attribution mass inside the mask
            divided by the total attribution mass. Recall is the attribution mass
            inside the mask divided by the number of mask pixels. Both denominators
            get ``1e-8`` added to avoid division by zero.
        """
        # sum of true positive attribution values
        sum_tp_attr = (attr_map * binary_gt_mask).sum()
        total_attr = attr_map.sum() + 1e-8
        total_gt_mask = binary_gt_mask.sum() + 1e-8

        precision = (sum_tp_attr / total_attr).item()
        recall = (sum_tp_attr / total_gt_mask).item()
        return precision, recall

    def summarize_step(self) -> None:
        """Zero-pad the cached per-image results to a common length and snapshot them.

        Some results hold only the overall and positive metrics (4 values), while
        others also hold the negative metrics (6 values). Shorter tuples are padded
        with ``0.0`` so every metric column has one entry per image.
        """
        max_len = max((len(result) for result in self._cache), default=0)
        self._cache[:] = [result + (0.0,) * (max_len - len(result)) for result in self._cache]
        self._results = deepcopy(self._cache)

    def summarize_total(self) -> Dict:
        """Return the summarized metrics as ``{metric_name: [value per image]}``.

        Raises:
            ValueError: If the summarized results do not have 4 or 6 columns. This
                includes the case where no image has been summarized.
        """
        # Transpose the per-image tuples into one list per metric.
        results = [list(column) for column in zip(*self._results)]
        if len(results) == 4:
            keys = self._BASE_KEYS
        elif len(results) == 6:
            keys = self._BASE_KEYS + self._NEG_KEYS
        else:
            raise ValueError(f'results have invalid length {len(results)}, it should be either 4 or 6.')
        return dict(zip(keys, results))

    def visualize_result(self, total_result: Dict, save_path: str) -> None:
        """Save one violin plot per metric, labelled with the metric's mean, to ``save_path``."""
        if len(total_result) == 4:
            fig, axes = plt.subplots(1, 4, figsize=(12, 9))
        else:
            fig, axes = plt.subplots(1, 6, figsize=(20, 5))

        try:
            for ax, (title, result) in zip(axes, total_result.items()):
                mean_result = np.mean(result)
                ax.set_title(title)
                # A single violin is drawn at x=1 (matplotlib's default position).
                ax.set_xticks([1])
                ax.set_xticklabels([f'Mean {title}: {mean_result:.4f}'])
                ax.violinplot(result)
            fig.tight_layout()
            fig.savefig(save_path, bbox_inches='tight')
        finally:
            # pyplot keeps every figure alive until it is closed; release it so
            # repeated calls do not leak memory.
            plt.close(fig)
