"""
AttackResult Class
======================

"""

from abc import ABC

from langdetect import detect
from langdetect.lang_detect_exception import LangDetectException

from textattack.goal_function_results import GoalFunctionResult
from textattack.shared import utils


class AttackResult(ABC):
    """Result of an Attack run on a single (output, text_input) pair.

    Args:
        original_result (:class:`~textattack.goal_function_results.GoalFunctionResult`):
            Result of the goal function applied to the original text
        perturbed_result (:class:`~textattack.goal_function_results.GoalFunctionResult`):
            Result of the goal function applied to the perturbed text. May or may not have been successful.

    Raises:
        ValueError: If either result is ``None``.
        TypeError: If either result is not a ``GoalFunctionResult``.
    """

    # Language codes, as returned by ``langdetect.detect``, for which
    # ``diff_color`` returns the plain texts instead of a word-level diff.
    _PLAIN_DIFF_LANGUAGES = frozenset({"zh-cn", "ko"})

    def __init__(self, original_result, perturbed_result):
        if original_result is None:
            raise ValueError("Attack original result cannot be None")
        elif not isinstance(original_result, GoalFunctionResult):
            raise TypeError(f"Invalid original goal function result: {original_result}")
        if perturbed_result is None:
            raise ValueError("Attack perturbed result cannot be None")
        elif not isinstance(perturbed_result, GoalFunctionResult):
            raise TypeError(
                f"Invalid perturbed goal function result: {perturbed_result}"
            )

        self.original_result = original_result
        self.perturbed_result = perturbed_result
        # The query count is taken from the perturbed result.
        self.num_queries = perturbed_result.num_queries

        # We don't want the AttackedText attributes sticking around clogging up
        # space on our devices. Delete them here, if they're still present,
        # because we won't need them anymore anyway.
        self.original_result.attacked_text.free_memory()
        self.perturbed_result.attacked_text.free_memory()

    def original_text(self, color_method=None):
        """Returns the text portion of `self.original_result`.

        Helper method. ``color_method`` is forwarded to ``printable_text`` as
        ``key_color_method``.
        """
        return self.original_result.attacked_text.printable_text(
            key_color=("bold", "underline"), key_color_method=color_method
        )

    def perturbed_text(self, color_method=None):
        """Returns the text portion of `self.perturbed_result`.

        Helper method. ``color_method`` is forwarded to ``printable_text`` as
        ``key_color_method``.
        """
        return self.perturbed_result.attacked_text.printable_text(
            key_color=("bold", "underline"), key_color_method=color_method
        )

    def str_lines(self, color_method=None):
        """A list of the lines to be printed for this result's string
        representation.

        The first entry is the goal function summary; it is followed by the
        two texts returned by :meth:`diff_color` (original, then perturbed).
        """
        lines = [self.goal_function_result_str(color_method=color_method)]
        lines.extend(self.diff_color(color_method))
        return lines

    def __str__(self, color_method=None):
        """Joins the entries of :meth:`str_lines` with blank lines."""
        return "\n\n".join(self.str_lines(color_method=color_method))

    def goal_function_result_str(self, color_method=None):
        """Returns a string illustrating the results of the goal function."""
        orig_colored = self.original_result.get_colored_output(color_method)
        pert_colored = self.perturbed_result.get_colored_output(color_method)
        return orig_colored + " --> " + pert_colored

    def diff_color(self, color_method=None):
        """Highlights the difference between two texts using color.

        Has to account for deletions and insertions from original text to
        perturbed. Relies on the index map stored in
        ``self.perturbed_result.attacked_text.attack_attrs["original_index_map"]``.

        Args:
            color_method: Passed to ``utils.color_text`` and ``printable_text``.
                When ``None``, no highlighting is applied.

        Returns:
            A 2-tuple ``(original_text, perturbed_text)`` of printable strings.
            Both are returned without highlighting when ``color_method`` is
            ``None`` or when the original text is detected as one of
            ``_PLAIN_DIFF_LANGUAGES``.
        """
        t1 = self.original_result.attacked_text
        t2 = self.perturbed_result.attacked_text

        # Check this first: without a color method the output is the plain
        # text regardless of language, so language detection would be wasted.
        if color_method is None:
            return t1.printable_text(), t2.printable_text()

        # Detect once; repeated calls are redundant and may not agree with
        # each other. Detection raises on text with no usable features, in
        # which case we simply fall through to the regular word-level diff.
        try:
            language = detect(t1.text)
        except LangDetectException:
            language = None
        if language in self._PLAIN_DIFF_LANGUAGES:
            return t1.printable_text(), t2.printable_text()

        color_1 = self.original_result.get_text_color_input()
        color_2 = self.perturbed_result.get_text_color_perturbed()

        # iterate through and count equal/unequal words
        words_1_idxs = []
        t2_equal_idxs = set()
        original_index_map = t2.attack_attrs["original_index_map"]
        for t1_idx, t2_idx in enumerate(original_index_map):
            if t2_idx == -1:
                # add words in t1 that are not in t2
                words_1_idxs.append(t1_idx)
            elif t1.words[t1_idx] == t2.words[t2_idx]:
                t2_equal_idxs.add(t2_idx)
            else:
                words_1_idxs.append(t1_idx)

        # words to color in t2 are all the words that didn't have an equal,
        # mapped word in t1
        words_2_idxs = sorted(set(range(t2.num_words)) - t2_equal_idxs)

        # make lists of colored words
        words_1 = [
            utils.color_text(t1.words[i], color_1, color_method) for i in words_1_idxs
        ]
        words_2 = [
            utils.color_text(t2.words[i], color_2, color_method) for i in words_2_idxs
        ]

        t1 = t1.replace_words_at_indices(words_1_idxs, words_1)
        t2 = t2.replace_words_at_indices(words_2_idxs, words_2)

        key_color = ("bold", "underline")
        return (
            t1.printable_text(key_color=key_color, key_color_method=color_method),
            t2.printable_text(key_color=key_color, key_color_method=color_method),
        )
