"""
Augmenter Class
===================
"""
import random

import tqdm

from textattack.constraints import PreTransformationConstraint
from textattack.metrics.quality_metrics import Perplexity, USEMetric
from textattack.shared import AttackedText, utils


class Augmenter:
    """A class for performing data augmentation using TextAttack.

    Returns all possible transformations for a given string. Currently only
    supports transformations which are word swaps.

    Args:
        transformation (textattack.Transformation): the transformation
            that suggests new texts from an input.
        constraints (list(textattack.Constraint)): constraints that each
            transformation must meet. ``PreTransformationConstraint``
            instances are handed to the transformation itself; all others
            filter its output. Defaults to no constraints.
        pct_words_to_swap (float): in [0., 1.], percentage of words to swap
            per augmented example. At least one word is always swapped.
        transformations_per_example (int): maximum number of augmentation
            runs per input. Must be positive.
        high_yield (bool): if True, keep every candidate that already has
            enough words swapped, so an input may yield several (relatively
            similar) augmentations per run instead of a single one.
        fast_augment (bool): stop additional transformation runs once the
            number of collected augmentations reaches
            ``transformations_per_example``.
        enable_advanced_metrics (bool): if True, ``augment`` also returns
            perplexity and USE score statistics of the augmentations. Stored
            on the instance as ``advanced_metrics``.

    Example::
        >>> from textattack.transformations import WordSwapRandomCharacterDeletion, WordSwapQWERTY, CompositeTransformation
        >>> from textattack.constraints.pre_transformation import RepeatModification, StopwordModification
        >>> from textattack.augmentation import Augmenter

        >>> transformation = CompositeTransformation([WordSwapRandomCharacterDeletion(), WordSwapQWERTY()])
        >>> constraints = [RepeatModification(), StopwordModification()]

        >>> # initiate augmenter
        >>> augmenter = Augmenter(
        ...     transformation=transformation,
        ...     constraints=constraints,
        ...     pct_words_to_swap=0.5,
        ...     transformations_per_example=3
        ... )

        >>> # additional parameters can also be changed after initialization
        >>> augmenter.advanced_metrics = True
        >>> augmenter.fast_augment = True
        >>> augmenter.high_yield = True

        >>> s = 'What I cannot create, I do not understand.'
        >>> results = augmenter.augment(s)

        >>> augmentations = results[0]
        >>> perplexity_score = results[1]
        >>> use_score = results[2]
    """

    def __init__(
        self,
        transformation,
        constraints=None,
        pct_words_to_swap=0.1,
        transformations_per_example=1,
        high_yield=False,
        fast_augment=False,
        enable_advanced_metrics=False,
    ):
        assert (
            transformations_per_example > 0
        ), "transformations_per_example must be a positive integer"
        assert 0.0 <= pct_words_to_swap <= 1.0, "pct_words_to_swap must be in [0., 1.]"
        # `None` instead of a shared mutable `[]` default.
        if constraints is None:
            constraints = []
        self.transformation = transformation
        self.pct_words_to_swap = pct_words_to_swap
        self.transformations_per_example = transformations_per_example
        self.high_yield = high_yield
        self.fast_augment = fast_augment
        self.advanced_metrics = enable_advanced_metrics

        # Pre-transformation constraints are passed into the transformation;
        # the remaining ones filter its output in `_filter_transformations`.
        self.constraints = []
        self.pre_transformation_constraints = []
        for constraint in constraints:
            if isinstance(constraint, PreTransformationConstraint):
                self.pre_transformation_constraints.append(constraint)
            else:
                self.constraints.append(constraint)

    def _filter_transformations(self, transformed_texts, current_text, original_text):
        """Filter ``transformed_texts`` down to those passing ``self.constraints``.

        Each constraint compares candidates against ``original_text`` when its
        ``compare_against_original`` flag is set, otherwise against
        ``current_text``.

        Raises:
            ValueError: if a constraint needs ``original_text`` but it is falsy.
        """
        for C in self.constraints:
            # Nothing left to check; skip the remaining constraints.
            if len(transformed_texts) == 0:
                break
            if C.compare_against_original:
                if not original_text:
                    raise ValueError(
                        f"Missing `original_text` argument when constraint {type(C)} is set to compare against "
                        f"`original_text` "
                    )

                transformed_texts = C.call_many(transformed_texts, original_text)
            else:
                transformed_texts = C.call_many(transformed_texts, current_text)
        return transformed_texts

    def augment(self, text):
        """Return augmentations of ``text`` produced by ``self.transformation``.

        Runs up to ``self.transformations_per_example`` chains. Each chain
        repeatedly transforms its current text and drops candidates that were
        already collected or that fail ``self.constraints``. It then continues
        from a randomly chosen survivor until enough words are swapped or no
        candidate remains.

        Args:
            text (str): the input text.

        Returns:
            list(str): sorted printable augmented texts. If
            ``self.advanced_metrics`` is set, the tuple
            ``(augmented_texts, perplexity_stats, use_stats)`` instead.
        """
        attacked_text = AttackedText(text)
        original_text = attacked_text
        all_transformed_texts = set()
        # Always swap at least one word, even for very short inputs.
        num_words_to_swap = max(
            int(self.pct_words_to_swap * len(attacked_text.words)), 1
        )
        for _ in range(self.transformations_per_example):
            current_text = attacked_text
            words_swapped = len(current_text.attack_attrs["modified_indices"])

            while words_swapped < num_words_to_swap:
                transformed_texts = self.transformation(
                    current_text, self.pre_transformation_constraints
                )

                # Get rid of transformations we already have.
                transformed_texts = [
                    t for t in transformed_texts if t not in all_transformed_texts
                ]

                # Filter out transformations that don't match the constraints.
                transformed_texts = self._filter_transformations(
                    transformed_texts, current_text, original_text
                )

                # If nothing survives the filters, this chain cannot progress.
                if not transformed_texts:
                    break

                if self.high_yield or self.fast_augment:
                    # Keep every candidate that already has enough words
                    # swapped, and continue the chain from one that does not.
                    # A single pass replaces the former O(n^2) list membership.
                    ready_texts = []
                    unfinished_texts = []
                    for candidate in transformed_texts:
                        if (
                            len(candidate.attack_attrs["modified_indices"])
                            >= num_words_to_swap
                        ):
                            ready_texts.append(candidate)
                        else:
                            unfinished_texts.append(candidate)
                    all_transformed_texts.update(ready_texts)

                    if not unfinished_texts:
                        # Every candidate meets `num_words_to_swap`; no need
                        # for further augmentation in this chain.
                        break
                    current_text = random.choice(unfinished_texts)
                else:
                    current_text = random.choice(transformed_texts)

                # Advance by at least one, so the loop terminates even if the
                # modified-index count does not grow.
                words_swapped = max(
                    len(current_text.attack_attrs["modified_indices"]),
                    words_swapped + 1,
                )

            # NOTE(review): if the chain stopped before choosing any candidate,
            # `current_text` is still the unmodified input and is added here.
            all_transformed_texts.add(current_text)

            # With fast_augment, stop early once there are enough augmentations.
            if (
                self.fast_augment
                and len(all_transformed_texts) >= self.transformations_per_example
            ):
                if not self.high_yield:
                    # random.sample() requires a sequence (sets raise TypeError
                    # on Python 3.11+). Sorting also makes the sample
                    # independent of set iteration order.
                    all_transformed_texts = random.sample(
                        sorted(
                            all_transformed_texts, key=lambda t: t.printable_text()
                        ),
                        self.transformations_per_example,
                    )
                break

        perturbed_texts = sorted(at.printable_text() for at in all_transformed_texts)

        if self.advanced_metrics:
            augmentation_results = [
                AugmentationResult(original_text, transformed)
                for transformed in all_transformed_texts
            ]
            perplexity_stats = Perplexity().calculate(augmentation_results)
            use_stats = USEMetric().calculate(augmentation_results)
            return perturbed_texts, perplexity_stats, use_stats

        return perturbed_texts

    def augment_many(self, text_list, show_progress=False):
        """Return the augmentations of each string in ``text_list``.

        Args:
            text_list (list(str)): strings to augment.
            show_progress (bool): wrap iteration in a ``tqdm`` progress bar.

        Returns:
            list: one ``augment`` result per input string, in input order.
        """
        if show_progress:
            text_list = tqdm.tqdm(text_list, desc="Augmenting data...")
        return [self.augment(text) for text in text_list]

    def augment_text_with_ids(self, text_list, id_list, show_progress=True):
        """Supplement a list of texts with their augmentations.

        Each input text appears once in the output, followed by its augmented
        texts. Every output text is paired with the ID of the input it came
        from.

        Args:
            text_list (list(str)): texts to augment.
            id_list (list): one ID per entry of ``text_list``.
            show_progress (bool): wrap iteration in a ``tqdm`` progress bar.

        Returns:
            tuple(list(str), list): the combined texts and their matching IDs.

        Raises:
            ValueError: if ``text_list`` and ``id_list`` differ in length.
        """
        if len(text_list) != len(id_list):
            raise ValueError("List of text must be same length as list of IDs")
        if self.transformations_per_example == 0:
            return text_list, id_list
        all_text_list = []
        all_id_list = []
        if show_progress:
            text_list = tqdm.tqdm(text_list, desc="Augmenting data...")
        for text, _id in zip(text_list, id_list):
            augmented_texts = self.augment(text)
            if isinstance(augmented_texts, tuple):
                # With advanced metrics enabled, augment() returns
                # (texts, perplexity_stats, use_stats); only texts are needed.
                augmented_texts = augmented_texts[0]
            # Add the original exactly once (it was previously added twice),
            # followed by its augmentations.
            all_text_list.append(text)
            all_text_list.extend(augmented_texts)
            all_id_list.extend([_id] * (1 + len(augmented_texts)))
        return all_text_list, all_id_list

    def __repr__(self):
        """Multi-line summary of the transformation and all constraints."""
        main_str = "Augmenter" + "("
        lines = []
        # self.transformation
        lines.append(utils.add_indent(f"(transformation):  {self.transformation}", 2))
        # self.constraints (post- and pre-transformation together)
        constraints_lines = []
        constraints = self.constraints + self.pre_transformation_constraints
        if len(constraints):
            for i, constraint in enumerate(constraints):
                constraints_lines.append(utils.add_indent(f"({i}): {constraint}", 2))
            constraints_str = utils.add_indent("\n" + "\n".join(constraints_lines), 2)
        else:
            constraints_str = "None"
        lines.append(utils.add_indent(f"(constraints): {constraints_str}", 2))
        main_str += "\n  " + "\n  ".join(lines) + "\n"
        main_str += ")"
        return main_str


class AugmentationResult:
    """Pair an original text with one of its augmentations.

    Exposes ``original_result.attacked_text`` and
    ``perturbed_result.attacked_text``, the shape handed to ``Perplexity`` and
    ``USEMetric`` when advanced metrics are enabled.
    """

    class tempResult:
        """Minimal holder exposing a single ``attacked_text`` attribute."""

        def __init__(self, text):
            self.attacked_text = text

    def __init__(self, text1, text2):
        wrap = self.tempResult
        self.original_result, self.perturbed_result = wrap(text1), wrap(text2)
