"""

Particle Swarm Optimization
====================================

Reimplementation of search method from Word-level Textual Adversarial
Attacking as Combinatorial Optimization by Zang et al.
`<https://www.aclweb.org/anthology/2020.acl-main.540.pdf>`_
`<https://github.com/thunlp/SememePSO-Attack>`_
"""
import copy

import numpy as np

from textattack.goal_function_results import GoalFunctionResultStatus
from textattack.search_methods import PopulationBasedSearch, PopulationMember
from textattack.shared import utils
from textattack.shared.validators import transformation_consists_of_word_swaps


class ParticleSwarmOptimization(PopulationBasedSearch):
    """Attacks a model with word substitutions using a Particle Swarm
    Optimization (PSO) algorithm. Some key hyper-parameters are setup according
    to the original paper:

    "We adjust PSO on the validation set of SST and set ω_1 as 0.8 and ω_2 as 0.2.
    We set the max velocity of the particles V_{max} to 3, which means the changing
    probability of the particles ranges from 0.047 (sigmoid(-3)) to 0.953 (sigmoid(3))."

    Args:
        pop_size (:obj:`int`, optional): The population size. Defaults to 60.
        max_iters (:obj:`int`, optional): The maximum number of iterations to use. Defaults to 20.
        post_turn_check (:obj:`bool`, optional): If `True`, check if new position reached by moving passes the constraints. Defaults to `True`
        max_turn_retries (:obj:`int`, optional): Maximum number of movement retries if new position after turning fails to pass the constraints.
            Applied only when `post_turn_check` is set to `True`.
            Setting it to 0 means we immediately take the old position as the new position upon failure.
            Defaults to 20.
    """

    def __init__(
        self, pop_size=60, max_iters=20, post_turn_check=True, max_turn_retries=20
    ):
        self.max_iters = max_iters
        self.pop_size = pop_size
        self.post_turn_check = post_turn_check
        # Use the caller's value (this must not be a hard-coded constant).
        self.max_turn_retries = max_turn_retries

        self._search_over = False
        # Inertia weight is annealed linearly from omega_1 down to omega_2.
        self.omega_1 = 0.8
        self.omega_2 = 0.2
        # The local/global move probabilities start at (c1_origin, c2_origin)
        # and are shifted linearly towards (c2_origin, c1_origin).
        self.c1_origin = 0.8
        self.c2_origin = 0.2
        self.v_max = 3.0

    def _perturb(self, pop_member, original_result):
        """Perturb `pop_member` in-place.

        Replaces the word at a randomly sampled index in `pop_member` with the
        replacement word that maximizes the increase in score.
        Args:
            pop_member (PopulationMember): The population member being perturbed.
            original_result (GoalFunctionResult): Result of original sample being attacked
        Returns:
            `True` if perturbation occurred. `False` if not.
        """
        # TODO: Below is very slow and is the main cause behind memory build up + slowness
        best_neighbors, prob_list = self._get_best_neighbors(
            pop_member.result, original_result
        )
        random_result = np.random.choice(best_neighbors, 1, p=prob_list)[0]

        if random_result == pop_member.result:
            return False
        else:
            pop_member.attacked_text = random_result.attacked_text
            pop_member.result = random_result
            return True

    def _equal(self, a, b):
        """Velocity contribution of one word: `-v_max` if `a == b`, else `v_max`."""
        return -self.v_max if a == b else self.v_max

    def _turn(self, source_text, target_text, prob, original_text):
        """
        Based on given probabilities, "move" to `target_text` from `source_text`.

        Each word index `i` independently takes `target_text`'s word with
        probability `prob[i]`; the remaining words are kept from `source_text`.
        Args:
            source_text (PopulationMember): Text we start from.
            target_text (PopulationMember): Text we want to move to.
            prob (np.array[float]): Turn probability for each word.
            original_text (AttackedText): Original text for constraint check if `self.post_turn_check=True`.
        Returns:
            A new `PopulationMember` (without a result) at the position we moved
            to. If `self.post_turn_check` is set and no attempt yields a changed
            text that passes the constraints, `source_text` itself is returned.
        """
        assert len(source_text.words) == len(
            target_text.words
        ), "Word length mismatch for turn operation."
        assert len(source_text.words) == len(
            prob
        ), "Length mismatch for words and probability list."
        len_x = len(source_text.words)

        num_tries = 0
        passed_constraints = False
        while num_tries < self.max_turn_retries + 1:
            indices_to_replace = []
            words_to_replace = []
            for i in range(len_x):
                if np.random.uniform() < prob[i]:
                    indices_to_replace.append(i)
                    words_to_replace.append(target_text.words[i])
            new_text = source_text.attacked_text.replace_words_at_indices(
                indices_to_replace, words_to_replace
            )
            # Indices copied from the target inherit the target's "modified"
            # status; all other indices keep the source's status.
            indices_to_replace = set(indices_to_replace)
            new_text.attack_attrs["modified_indices"] = (
                source_text.attacked_text.attack_attrs["modified_indices"]
                - indices_to_replace
            ) | (
                target_text.attacked_text.attack_attrs["modified_indices"]
                & indices_to_replace
            )
            if "last_transformation" in source_text.attacked_text.attack_attrs:
                new_text.attack_attrs[
                    "last_transformation"
                ] = source_text.attacked_text.attack_attrs["last_transformation"]

            if not self.post_turn_check or (new_text.words == source_text.words):
                break

            if "last_transformation" in new_text.attack_attrs:
                passed_constraints = self._check_constraints(
                    new_text, source_text.attacked_text, original_text=original_text
                )
            else:
                passed_constraints = True

            if passed_constraints:
                break

            num_tries += 1

        if self.post_turn_check and not passed_constraints:
            # If we cannot find a turn that passes the constraints, we do not move.
            return source_text
        else:
            return PopulationMember(new_text)

    def _get_best_neighbors(self, current_result, original_result):
        """For given current text, find its neighboring texts that yields
        maximum improvement (in goal function score) for each word.

        Args:
            current_result (GoalFunctionResult): `GoalFunctionResult` of current text
            original_result (GoalFunctionResult): `GoalFunctionResult` of original text.
        Returns:
            best_neighbors (list[GoalFunctionResult]): Best neighboring text for each word
                (`current_result` itself where a word has no scored neighbor).
            prob_list (list[float]): discrete probability distribution for sampling a neighbor from `best_neighbors`
        """
        current_text = current_result.attacked_text
        neighbors_list = [[] for _ in range(len(current_text.words))]
        transformed_texts = self.get_transformations(
            current_text, original_text=original_result.attacked_text
        )
        for transformed_text in transformed_texts:
            # Only the first newly modified index is used to bucket the
            # neighbor; presumably word swaps modify exactly one index.
            diff_idx = next(
                iter(transformed_text.attack_attrs["newly_modified_indices"])
            )
            neighbors_list[diff_idx].append(transformed_text)

        best_neighbors = []
        score_list = []
        for i in range(len(neighbors_list)):
            if not neighbors_list[i]:
                best_neighbors.append(current_result)
                score_list.append(0)
                continue

            neighbor_results, self._search_over = self.get_goal_results(
                neighbors_list[i]
            )
            if not len(neighbor_results):
                best_neighbors.append(current_result)
                score_list.append(0)
            else:
                neighbor_scores = np.array([r.score for r in neighbor_results])
                score_diff = neighbor_scores - current_result.score
                best_idx = np.argmax(neighbor_scores)
                best_neighbors.append(neighbor_results[best_idx])
                score_list.append(score_diff[best_idx])

        prob_list = normalize(score_list)

        return best_neighbors, prob_list

    def _initialize_population(self, initial_result, pop_size):
        """
        Initialize a population of size `pop_size` with `initial_result`
        Args:
            initial_result (GoalFunctionResult): Original text
            pop_size (int): size of population
        Returns:
            population as `list[PopulationMember]`
        """
        best_neighbors, prob_list = self._get_best_neighbors(
            initial_result, initial_result
        )
        population = []
        for _ in range(pop_size):
            # Mutation step
            random_result = np.random.choice(best_neighbors, 1, p=prob_list)[0]
            population.append(
                PopulationMember(random_result.attacked_text, random_result)
            )
        return population

    def _early_exit_result(self, population, global_elite):
        """Decide whether the search should stop now.

        Args:
            population (list[PopulationMember]): Current particles, all with results.
            global_elite (PopulationMember): Best particle recorded so far.
        Returns:
            The top-scoring member's result if it succeeded; otherwise, when the
            query budget is exhausted, the best result among `population` and
            `global_elite`; otherwise `None` (keep searching).
        """
        if population:
            top_member = max(population, key=lambda x: x.score)
            if top_member.result.goal_status == GoalFunctionResultStatus.SUCCEEDED:
                return top_member.result
        if self._search_over:
            # `population` may be empty if the budget ran out before any of
            # its members were scored; the global elite is always available.
            return max(population + [global_elite], key=lambda x: x.score).result
        return None

    def perform_search(self, initial_result):
        """Run the PSO search starting from `initial_result`.

        Args:
            initial_result (GoalFunctionResult): Result of the original text.
        Returns:
            GoalFunctionResult: the top-scoring member's result as soon as it
            succeeds; otherwise, when the query budget runs out or after
            `max_iters` iterations, the best-scoring result seen so far.
        """
        self._search_over = False
        original_text = initial_result.attacked_text
        population = self._initialize_population(initial_result, self.pop_size)

        # Every word of particle t starts with the same random velocity v_init[t].
        num_words = original_text.num_words
        v_init = np.random.uniform(-self.v_max, self.v_max, self.pop_size)
        velocities = np.array([[v] * num_words for v in v_init])

        # Elites are kept as independent shallow copies: population members are
        # updated in place (`result` reassignment, `_perturb`), and an elite
        # aliased to a member would be silently overwritten by a worse position.
        global_elite = copy.copy(max(population, key=lambda x: x.score))
        early_result = self._early_exit_result(population, global_elite)
        if early_result is not None:
            return early_result

        local_elites = [copy.copy(member) for member in population]

        for i in range(self.max_iters):
            # Inertia weight decays linearly from omega_1 to omega_2.
            omega = (self.omega_1 - self.omega_2) * (
                self.max_iters - i
            ) / self.max_iters + self.omega_2
            # The probability of moving towards the local elite decays from
            # c1_origin to c2_origin, while the probability of moving towards
            # the global elite grows from c2_origin to c1_origin.
            progress = i / self.max_iters
            spread = self.c1_origin - self.c2_origin
            p_local = self.c1_origin - progress * spread
            p_global = self.c2_origin + progress * spread

            global_elite_words = global_elite.words
            for k in range(len(population)):
                # calculate the probability of turning each word
                pop_mem_words = population[k].words
                local_elite_words = local_elites[k].words
                assert len(pop_mem_words) == len(
                    local_elite_words
                ), "PSO word length mismatch!"

                # Words differing from an elite push the velocity (and thus the
                # turn probability) up; words that already agree push it down.
                for d, word in enumerate(pop_mem_words):
                    attraction = self._equal(word, local_elite_words[d]) + self._equal(
                        word, global_elite_words[d]
                    )
                    velocities[k][d] = (
                        omega * velocities[k][d] + (1 - omega) * attraction
                    )
                # NOTE: velocities are not clipped to [-v_max, v_max]; v_max only
                # bounds the initial draw and each per-step `_equal` term.
                turn_prob = utils.sigmoid(velocities[k])

                if np.random.uniform() < p_local:
                    # Move from the current position towards the local elite.
                    population[k] = self._turn(
                        population[k], local_elites[k], turn_prob, original_text
                    )

                if np.random.uniform() < p_global:
                    # Move from the current position towards the global elite.
                    population[k] = self._turn(
                        population[k], global_elite, turn_prob, original_text
                    )

            # Check if there is any successful attack in the current population
            pop_results, self._search_over = self.get_goal_results(
                [member.attacked_text for member in population]
            )
            if self._search_over:
                # `get_goal_results` may have been cut short by the query
                # budget; keep only the members that received a result.
                population = population[: len(pop_results)]
            for member, result in zip(population, pop_results):
                member.result = result

            early_result = self._early_exit_result(population, global_elite)
            if early_result is not None:
                return early_result

            # Mutation based on the current change rate
            for member in population:
                change_ratio = original_text.words_diff_ratio(member.attacked_text)
                # Referred from the original source code
                p_change = 1 - 2 * change_ratio
                if np.random.uniform() < p_change:
                    self._perturb(member, initial_result)

                if self._search_over:
                    break

            # Check if there is any successful attack in the current population
            early_result = self._early_exit_result(population, global_elite)
            if early_result is not None:
                return early_result

            # Update the elites wherever the score increased.
            for k, member in enumerate(population):
                if member.score > local_elites[k].score:
                    local_elites[k] = copy.copy(member)

            top_member = max(population, key=lambda x: x.score)
            if top_member.score > global_elite.score:
                global_elite = copy.copy(top_member)

        return global_elite.result

    def check_transformation_compatibility(self, transformation):
        """The particle swarm optimization algorithm is specifically designed
        for word substitutions."""
        return transformation_consists_of_word_swaps(transformation)

    @property
    def is_black_box(self):
        return True

    def extra_repr_keys(self):
        return ["pop_size", "max_iters", "post_turn_check", "max_turn_retries"]


def normalize(n):
    """Turn scores into a discrete probability distribution.

    Negative entries are treated as zero. If the clipped scores sum to zero,
    a uniform distribution over all entries is returned instead.

    Args:
        n (array-like): Scores to normalize.
    Returns:
        np.ndarray: One probability per entry of `n`.
    """
    weights = np.array(n)
    weights[weights < 0] = 0
    total = np.sum(weights)
    if total != 0:
        return weights / total
    count = len(weights)
    return np.ones(count) / count
