"""

TextAttack Constraint Class
=====================================
"""

from abc import ABC, abstractmethod

import textattack
from textattack.shared.utils import ReprMixin


class Constraint(ReprMixin, ABC):
    """An abstract class that represents constraints on adversarial text
    examples. Constraints evaluate whether transformations from an
    ``AttackedText`` to another ``AttackedText`` meet certain conditions.

    Args:
        compare_against_original (bool): If `True`, the reference text should be the original text under attack.
            If `False`, the reference text is the most recent text from which the transformed text was generated.
            All constraints must have this attribute.
    """

    def __init__(self, compare_against_original):
        self.compare_against_original = compare_against_original

    def call_many(self, transformed_texts, reference_text):
        """Filters ``transformed_texts`` based on which transformations fulfill
        the constraint. First checks compatibility with latest
        ``Transformation``, then calls ``_check_constraint_many``.

        Texts whose last transformation is incompatible with this constraint
        are passed through unchecked.

        Args:
            transformed_texts (list[AttackedText]): The candidate transformed ``AttackedText``'s.
            reference_text (AttackedText): The ``AttackedText`` to compare against.

        Returns:
            list[AttackedText]: The compatible texts that passed
            ``_check_constraint_many``, followed by all incompatible texts.
            Note that this is not necessarily the input order.

        Raises:
            KeyError: If a transformed text lacks the ``last_transformation``
                entry in its ``attack_attrs``.
        """
        incompatible_transformed_texts = []
        compatible_transformed_texts = []
        for transformed_text in transformed_texts:
            # Only the lookup is guarded, so a KeyError raised inside a
            # subclass's ``check_compatibility`` is not misreported as a
            # missing ``last_transformation`` attribute.
            try:
                last_transformation = transformed_text.attack_attrs[
                    "last_transformation"
                ]
            except KeyError as err:
                raise KeyError(
                    "transformed_text must have `last_transformation` attack_attr to apply constraint"
                ) from err
            if self.check_compatibility(last_transformation):
                compatible_transformed_texts.append(transformed_text)
            else:
                incompatible_transformed_texts.append(transformed_text)
        filtered_texts = self._check_constraint_many(
            compatible_transformed_texts, reference_text
        )
        # ``list(...)`` tolerates subclasses whose ``_check_constraint_many``
        # returns a non-list iterable.
        return list(filtered_texts) + incompatible_transformed_texts

    def _check_constraint_many(self, transformed_texts, reference_text):
        """Filters ``transformed_texts`` based on which transformations fulfill
        the constraint. Calls ``_check_constraint`` on each text.

        Subclasses may override this with a batched implementation.

        Args:
            transformed_texts (list[AttackedText]): The candidate transformed ``AttackedText``'s.
            reference_text (AttackedText): The ``AttackedText`` to compare against.

        Returns:
            list[AttackedText]: The texts for which ``_check_constraint``
            returned a truthy value, in input order.
        """
        return [
            transformed_text
            for transformed_text in transformed_texts
            if self._check_constraint(transformed_text, reference_text)
        ]

    def __call__(self, transformed_text, reference_text):
        """Returns True if the constraint is fulfilled, False otherwise. First
        checks compatibility with latest ``Transformation``, then calls
        ``_check_constraint``.

        Args:
            transformed_text (AttackedText): The candidate transformed ``AttackedText``.
            reference_text (AttackedText): The ``AttackedText`` to compare against.

        Returns:
            bool: ``True`` when the last transformation is incompatible with
            this constraint (the constraint does not apply). Otherwise, the
            result of ``_check_constraint``.

        Raises:
            TypeError: If either argument is not an ``AttackedText``.
            KeyError: If ``transformed_text`` lacks the
                ``last_transformation`` entry in its ``attack_attrs``.
        """
        if not isinstance(transformed_text, textattack.shared.AttackedText):
            raise TypeError("transformed_text must be of type AttackedText")
        if not isinstance(reference_text, textattack.shared.AttackedText):
            raise TypeError("reference_text must be of type AttackedText")

        # Guard only the lookup; errors from ``check_compatibility`` propagate
        # unchanged instead of being reported as a missing attribute.
        try:
            last_transformation = transformed_text.attack_attrs["last_transformation"]
        except KeyError as err:
            raise KeyError(
                "`transformed_text` must have `last_transformation` attack_attr to apply constraint."
            ) from err
        if not self.check_compatibility(last_transformation):
            # The constraint does not apply to this kind of transformation.
            return True
        return self._check_constraint(transformed_text, reference_text)

    @abstractmethod
    def _check_constraint(self, transformed_text, reference_text):
        """Returns True if the constraint is fulfilled, False otherwise. Must
        be overridden by the specific constraint.

        Args:
            transformed_text (AttackedText): The candidate transformed ``AttackedText``.
            reference_text (AttackedText): The ``AttackedText`` to compare against.
        """
        raise NotImplementedError()

    def check_compatibility(self, transformation):
        """Checks if this constraint is compatible with the given
        transformation. For example, the ``WordEmbeddingDistance`` constraint
        compares the embedding of the word inserted with that of the word
        deleted. Therefore it can only be applied in the case of word swaps,
        and not for transformations which involve only one of insertion or
        deletion.

        The base implementation accepts every transformation.

        Args:
            transformation: The ``Transformation`` to check compatibility with.

        Returns:
            bool: ``True`` if the constraint should be applied to texts
            produced by ``transformation``.
        """
        return True

    def extra_repr_keys(self):
        """Set the extra representation of the constraint using these keys.

        To print customized extra information, you should reimplement
        this method in your own constraint, returning the list of
        attribute names to include.

        Returns:
            list[str]: Attribute names used for the extra representation.
        """
        return ["compare_against_original"]
