"""

ClassificationGoalFunctionResult Class
========================================

"""

import torch

import textattack
from textattack.shared import utils

from .goal_function_result import GoalFunctionResult


class ClassificationGoalFunctionResult(GoalFunctionResult):
    """Goal-function result produced for a classification model.

    Extends :class:`GoalFunctionResult` with helpers that turn the model's
    raw scores into a displayable label (optionally a human-readable label
    name) plus an associated color.
    """

    def __init__(
        self,
        attacked_text,
        raw_output,
        output,
        goal_status,
        score,
        num_queries,
        ground_truth_output,
    ):
        # All fields are stored by the base class; this subclass only tags
        # the result type as "Classification".
        super().__init__(
            attacked_text,
            raw_output,
            output,
            goal_status,
            score,
            num_queries,
            ground_truth_output,
            goal_function_result_type="Classification",
        )

    @property
    def _processed_output(self):
        """Return a ``(display_output, color)`` pair for this result.

        When ``self.attacked_text.attack_attrs`` provides ``label_names``, the
        display output is the processed label name (e.g. ``positive``) found
        at index ``self.output``. Otherwise, it is the argmax index of
        ``self.raw_output`` (e.g. ``1``).
        """
        predicted_index = self.raw_output.argmax()
        names = self.attacked_text.attack_attrs.get("label_names")

        if names is None:
            # No label names available: show the raw predicted index.
            return predicted_index, utils.color_from_label(predicted_index)

        # NOTE(review): the name is looked up via ``self.output`` while the
        # color uses the argmax of ``raw_output``; these presumably agree for
        # classification goal functions -- confirm against the caller.
        label_name = utils.process_label_name(names[self.output])
        return label_name, utils.color_from_output(label_name, predicted_index)

    def get_text_color_input(self):
        """Color for this result's changed portion when it is shown as the
        original input."""
        return self._processed_output[1]

    def get_text_color_perturbed(self):
        """Color for this result's changed portion when it is shown as the
        perturbed input."""
        return self._processed_output[1]

    def get_colored_output(self, color_method=None):
        """Render the output as ``"<label> (<confidence>%)"``, colored via
        ``utils.color_text`` using ``color_method``.

        The confidence is the raw score at the argmax index, formatted as a
        whole-number percentage.
        """
        predicted_index = self.raw_output.argmax()
        confidence = self.raw_output[predicted_index]
        # Unwrap 0-dim torch tensors so the percent format spec applies to a
        # plain Python number.
        if isinstance(confidence, torch.Tensor):
            confidence = confidence.item()

        label, color = self._processed_output
        rendered = f"{label} ({confidence:.0%})"
        return utils.color_text(rendered, color=color, method=color_method)
