"""
Explanation behavior using an explicand representation's self-weighted importance, as
proposed by Crabbe et al. Label-Free Explainability for Unsupervised Models
(https://arxiv.org/abs/2203.01928) and by Wickstrom et al. RELAX: Representation
Learning Explainability (https://arxiv.org/abs/2112.10161).
"""

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from cl_explain.explanations.explanation_base import ExplanationBase


class WeightedScore(ExplanationBase):
    """
    Module class for label-free feature importance proposed by Crabbe et al. 2022
    (without normalization) or RELAX by Wickstrom et al. 2022 (with normalization).

    The score of an explicand is the similarity between its encoder representation
    and a stored weight representation, which must be set beforehand by calling
    `generate_weight`.

    Args:
    ----
        encoder: Encoder module to explain.
        normalize: Whether to normalize dot product similarity by product of vector
            norms (that is, whether to use cosine similarity).
    """

    def __init__(self, encoder: nn.Module, normalize: bool) -> None:
        super().__init__(encoder=encoder)
        self.normalize = normalize
        # Detached encoder output used as the weight; None until generate_weight
        # has been called.
        self.weight = None

    def forward(self, explicand: torch.Tensor) -> torch.Tensor:
        """Return the pointwise similarity between explicands and the weight."""
        return self._compute_pointwise_similarity(explicand)

    def generate_weight(self, x: torch.Tensor) -> None:
        """
        Use inputs to generate and update weight for computing weighted score.

        Args:
        ----
            x: Inputs passed through the encoder; the detached output replaces any
                previously stored weight.
        """
        # The weight is a constant with respect to autograd, so avoid building (and
        # holding in memory) a graph that would be discarded by detach() anyway.
        with torch.no_grad():
            self.weight = self.encoder(x).detach()

    def _compute_pointwise_similarity(self, explicand: torch.Tensor) -> torch.Tensor:
        """
        Compute pointwise similarities between explicands and weight vectors.

        Args:
        ----
            explicand: Inputs to encode and compare against `self.weight`. The
                encoded explicands are multiplied elementwise with the weight, so
                the two must be broadcastable.

        Returns
        -------
            Dot product over the last dimension, divided by the product of the two
            vector norms when `self.normalize` is set.

        Raises
        ------
            AssertionError: If `generate_weight` has not been called yet.
        """
        if self.weight is None:
            # Raised explicitly instead of via `assert` so that the check is not
            # stripped when Python runs with -O; type and message are unchanged.
            raise AssertionError(
                "WeightedScore.weight needs to be generated by calling the method"
                " WeightedScore.generate_weight"
            )
        explicand_rep = self.encoder(explicand)
        similarity = (self.weight * explicand_rep).sum(dim=-1)
        if self.normalize:
            # NOTE: no epsilon is applied; if either representation has zero norm
            # the result for that entry is NaN (0 / 0).
            norm_product = self.weight.norm(dim=-1) * explicand_rep.norm(dim=-1)
            similarity = similarity / norm_product
        return similarity

    def _compute_pairwise_similarity(
        self,
        explicand: torch.Tensor,
        rep_dataloader: DataLoader,
        rep_data_size: int,
    ) -> torch.Tensor:
        """
        Compute pairwise similarities between explicands and some representations.

        Args:
        ----
            explicand: Inputs to encode with `self.encoder`.
            rep_dataloader: Loader whose batches are 1-tuples. Each batch element is
                compared to the encoded explicands without being encoded here, so
                it is presumably a precomputed representation (confirm with caller).
            rep_data_size: Divisor used to average the accumulated similarities.
                Presumably the total number of representations yielded by
                `rep_dataloader` (confirm with caller).

        Returns
        -------
            Similarities summed over dim 1 of each batch result, accumulated over
            all batches and divided by `rep_data_size`. If the loader yields no
            batches, this is the plain Python number `0 / rep_data_size` rather
            than a tensor.
        """
        explicand_rep = self.encoder(explicand)
        # `_compute_cosine_similarity` / `_compute_dot_product` are not defined in
        # this class; presumably they are inherited from ExplanationBase.
        similarity = 0
        for (rep_batch,) in rep_dataloader:
            rep_batch = rep_batch.to(explicand_rep.device)
            if self.normalize:
                batch_similarity = self._compute_cosine_similarity(
                    explicand_rep, rep_batch
                )
            else:
                batch_similarity = self._compute_dot_product(explicand_rep, rep_batch)
            similarity += batch_similarity.sum(dim=1)
        return similarity / rep_data_size  # Average over number of comparisons.
