"""Single-file implementation of the GCG attack with additional options.

@article{zou2023universal,
  title={Universal and transferable adversarial attacks on aligned language models},
  author={Zou, Andy and Wang, Zifan and Carlini, Nicholas and Nasr, Milad and Kolter, J Zico and Fredrikson, Matt},
  journal={arXiv preprint arXiv:2307.15043},
  year={2023}
}

Extensively tested against a variety of models, including:
    cais/zephyr_7b_r2d2
    ContinuousAT/Llama-2-7B-CAT
    ContinuousAT/Phi-CAT
    ContinuousAT/Zephyr-CAT
    google/gemma-2-2b-it
    GraySwanAI/Llama-3-8B-Instruct-RR
    GraySwanAI/Mistral-7B-Instruct-RR
    HuggingFaceH4/zephyr-7b-beta
    meta-llama/Llama-2-7b-chat-hf
    meta-llama/Meta-Llama-3.1-8B-Instruct
    microsoft/Phi-3-mini-4k-instruct
    mistralai/Mistral-7B-Instruct-v0.3
    qwen/Qwen2-7B-Instruct

The implementation is inspired by nanoGCG but fixes several of its issues,
mostly related to tokenization.
"""
import gc
import logging
import random
import sys
import time

from dataclasses import dataclass, field
from functools import partial

import torch
import transformers
from torch import Tensor
from tqdm import trange
from transformers import DynamicCache, PreTrainedModel, PreTrainedTokenizer

from src.lm_utils import (filter_suffix, generate_ragged_batched,
                          get_disallowed_ids, prepare_conversation,
                          with_max_batchsize, TokenMergeError)

from src.dataset import PromptDataset
from .attack import Attack, AttackResult, GenerationConfig, SingleAttackRunResult, AttackStepResult


@dataclass
class GCGConfig:
    """Hyper-parameters for the GCG attack.

    The last three fields (``loss``, ``grad_smoothing``, ``grad_momentum``) are read by
    ``SubstitutionSelectionStrategy``; they are appended at the end so positional
    construction of existing configs keeps working.
    """

    # NOTE(review): name/type/version/placement/seed/verbosity/token_selection are not read
    # in this file; presumably consumed by the Attack base class or the experiment runner.
    name: str = "gcg"
    type: str = "discrete"
    version: str = ""
    placement: str = "suffix"
    # Sampling settings for the completions generated after optimization.
    generation_config: GenerationConfig = field(default_factory=GenerationConfig)
    num_steps: int = 250  # optimization steps per conversation
    seed: int = 0
    # Initial attack string, appended to the user turn.
    optim_str_init: str = "x x x x x x x x x x x x x x x x x x x x"
    search_width: int = 512  # candidate sequences sampled per step
    topk: int = 256  # per-position candidates taken from the lowest gradient entries
    n_replace: int = 1  # token positions replaced in each candidate
    buffer_size: int = 0  # 0 keeps only the most recently accepted candidate
    use_constrained_gradient: bool = False
    mellowmax_alpha: float = 1.0  # temperature for the "mellowmax" loss
    early_stop: bool = False  # stop once some candidate reproduces the full target greedily
    use_prefix_cache: bool = True  # cache KV states of the tokens preceding the suffix
    allow_non_ascii: bool = False
    allow_special: bool = False
    filter_ids: bool = True  # drop candidates that do not re-tokenize consistently
    verbosity: str = "WARNING"
    token_selection: str = "default"
    grow_target: bool = False  # start with a 1-token target and grow it on exact matches
    loss: str = "ce"  # loss type forwarded to compute_loss
    grad_smoothing: int = 1  # >1 averages the gradient over random single-token perturbations
    grad_momentum: float = 0.0  # >0 keeps an exponential moving average of the gradient


def compute_loss(
    shift_logits: Tensor,
    shift_labels: Tensor,
    loss_type: str = "ce",
    mellowmax_alpha: float = 1.0,
    not_allowed_ids: "Tensor | None" = None,
) -> Tensor:
    """Computes the per-sequence target loss.

    Args:
        shift_logits: Tensor of shape (batch_size, seq_len, vocab_size); position t holds
            the logits that predict ``shift_labels[:, t]``.
        shift_labels: Tensor of shape (batch_size, seq_len) with the target token ids.
        loss_type: 'ce' (token cross-entropy averaged over the sequence) or
            'mellowmax' (mellowmax of the negated target logits, as in nanoGCG).
        mellowmax_alpha: Temperature of the mellowmax operator (only used for 'mellowmax').
        not_allowed_ids: Accepted for call-site compatibility; not used by the
            currently implemented loss types.

    Returns:
        loss: Tensor of shape (batch_size,)

    Raises:
        NotImplementedError: If the loss type is not implemented
    """
    if loss_type == "ce":
        # reshape (not view) so that non-contiguous logits are also accepted.
        token_loss = torch.nn.functional.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
            reduction="none",
        )
        return token_loss.view(shift_logits.shape[0], -1).mean(dim=-1)

    if loss_type == "mellowmax":
        # Logit assigned to each target token: (batch_size, seq_len).
        label_logits = torch.gather(shift_logits, -1, shift_labels.unsqueeze(-1)).squeeze(-1)
        neg_logits = -label_logits
        seq_len = torch.tensor(neg_logits.shape[-1], dtype=neg_logits.dtype, device=neg_logits.device)
        # mellowmax_alpha(x) = (logsumexp(alpha * x) - log(n)) / alpha
        return (torch.logsumexp(mellowmax_alpha * neg_logits, dim=-1) - torch.log(seq_len)) / mellowmax_alpha

    raise NotImplementedError(f"Loss type '{loss_type}' is not implemented")


class GCGAttack(Attack):
    """Greedy Coordinate Gradient (GCG) attack that optimizes an adversarial suffix.

    For every conversation, ``config.optim_str_init`` is appended to the user turn and its
    tokens are iteratively substituted to lower the loss of the assistant's target reply.
    """

    def __init__(self, config: GCGConfig):
        super().__init__(config)
        self.logger = logging.getLogger("nanogcg")
        # Attach a handler only once so repeated instantiation does not duplicate log lines.
        if not self.logger.hasHandlers():
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                "%(asctime)s [%(filename)s:%(lineno)d] %(message)s",
                datefmt="%Y-%m-%d %H:%M:%S",
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
            self.logger.setLevel(logging.INFO)

    def run(self, model: PreTrainedModel, tokenizer: PreTrainedTokenizer, dataset: PromptDataset) -> AttackResult:
        """Run GCG on every conversation of ``dataset``.

        Args:
            model: the model under attack.
            tokenizer: the tokenizer matching ``model``.
            dataset: iterable of conversations; each is indexed as ``[0]`` (user turn) and
                ``[1]`` (assistant turn whose content is the target) — TODO confirm no
                longer conversations are expected.

        Returns:
            AttackResult with one SingleAttackRunResult per conversation, each holding one
            AttackStepResult (with generated completions) per optimization step.
        """
        self.not_allowed_ids = get_disallowed_ids(tokenizer, self.config.allow_non_ascii, self.config.allow_special).to(model.device)
        runs = []
        for conversation in dataset:
            t0 = time.time()
            try:
                attack_conversation = [
                    {"role": "user", "content": conversation[0]["content"] + self.config.optim_str_init},
                    {"role": "assistant", "content": conversation[1]["content"]},
                ]
                pre_ids, attack_prefix_ids, prompt_ids, attack_suffix_ids, post_ids, target_ids = prepare_conversation(tokenizer, conversation, attack_conversation)[0]
            except TokenMergeError:
                # Presumably the prompt and the attack string merged into shared tokens;
                # retry with a separating space.
                attack_conversation = [
                    {"role": "user", "content": conversation[0]["content"] + " " + self.config.optim_str_init},
                    {"role": "assistant", "content": conversation[1]["content"]},
                ]
                pre_ids, attack_prefix_ids, prompt_ids, attack_suffix_ids, post_ids, target_ids = prepare_conversation(tokenizer, conversation, attack_conversation)[0]

            pre_ids = pre_ids.unsqueeze(0).to(model.device)
            prompt_ids = prompt_ids.unsqueeze(0).to(model.device)
            pre_prompt_ids = torch.cat([pre_ids, prompt_ids], dim=1)
            attack_ids = attack_suffix_ids.unsqueeze(0).to(model.device)
            post_ids = post_ids.unsqueeze(0).to(model.device)
            target_ids = target_ids.unsqueeze(0).to(model.device)

            # Embed everything that doesn't get optimized
            embedding_layer = model.get_input_embeddings()
            pre_prompt_embeds, post_embeds, target_embeds = [
                embedding_layer(ids) for ids in (pre_prompt_ids, post_ids, target_ids)
            ]

            # Compute the KV Cache for tokens that appear before the optimized tokens
            if self.config.use_prefix_cache and "gemma-2" not in model.name_or_path:
                with torch.no_grad():
                    self.prefix_cache = DynamicCache()
                    output = model(inputs_embeds=pre_prompt_embeds, past_key_values=self.prefix_cache, use_cache=True)
                    self.prefix_cache = output.past_key_values
            else:
                self.prefix_cache = None

            self.target_ids = target_ids
            self.pre_prompt_embeds = pre_prompt_embeds
            self.post_embeds = post_embeds
            self.target_embeds = target_embeds

            full_target_length = target_ids.size(1)
            self.target_length = 1 if self.config.grow_target else full_target_length
            # Initialize the attack buffer
            buffer = self.init_buffer(model, attack_ids)
            optim_ids = buffer.get_best_ids()
            token_selection = SubstitutionSelectionStrategy(
                self.config,
                self.prefix_cache,
                self.pre_prompt_embeds,
                self.post_embeds,
                self.target_embeds,
                self.target_ids,
                self.not_allowed_ids,
            )

            losses = []
            times = []
            optim_strings = []
            self.stop_flag = False
            current_loss = buffer.get_lowest_loss()
            for _ in (pbar := trange(self.config.num_steps, file=sys.stdout)):
                t0a = time.time()
                # The gradient is taken w.r.t. the (possibly truncated) current target.
                token_selection.target_ids = self.target_ids[:, :self.target_length]
                token_selection.target_embeds = self.target_embeds[:, :self.target_length]
                # Compute the token gradient and sample candidate token sequences
                sampled_ids, sampled_ids_pos, grad = token_selection(
                    optim_ids.squeeze(0),
                    model,
                    self.config.search_width,
                    self.config.topk,
                    self.config.n_replace,
                    not_allowed_ids=self.not_allowed_ids,
                )
                with torch.no_grad():
                    if self.config.filter_ids:
                        # We're trying to be as strict as possible here, so we filter
                        # the entire prompt, not just the attack sequence in an isolated
                        # way. This is because the prompt and attack can affect each
                        # other's tokenization in some cases.
                        idx = filter_suffix(
                            tokenizer,
                            conversation,
                            [[None, sampled_ids.cpu()]],
                        )
                        sampled_ids = sampled_ids[idx]
                        sampled_ids_pos = sampled_ids_pos[idx]

                    if sampled_ids.size(0) == 0:
                        # Nothing to evaluate (loss.min() would fail); keep the current
                        # best attack and record the previous loss for this step.
                        self.logger.warning("All candidates were filtered out; keeping the current best attack.")
                    else:
                        compute_loss_fn = partial(self.compute_candidates_loss, model)
                        loss, acc = with_max_batchsize(compute_loss_fn, sampled_ids)

                        best_idx = loss.argmin()
                        current_loss = loss[best_idx].item()
                        optim_ids = sampled_ids[best_idx].unsqueeze(0)
                        if self.config.grow_target and acc[best_idx]:
                            # Never grow beyond the actual target length.
                            self.target_length = min(self.target_length + 1, full_target_length)
                        # Update the buffer based on the loss
                        if buffer.size == 0 or current_loss < buffer.get_highest_loss():
                            buffer.add(current_loss, optim_ids)
                    losses.append(current_loss)
                    times.append(time.time() - t0a)

                optim_ids = buffer.get_best_ids()
                optim_str = tokenizer.batch_decode(optim_ids)[0]
                optim_strings.append(optim_str)
                pbar.set_postfix({"Loss": current_loss, "# TGT Toks": self.target_length, "Best Attack": optim_str[:80]})

                if self.stop_flag:
                    self.logger.info("Early stopping due to finding a perfect match.")
                    break

            # Generate completions for the best attack string of every step.
            token_list = []
            attack_conversations = []
            for attack in optim_strings:
                attack_conversation = [
                    {"role": "user", "content": conversation[0]["content"] + attack},
                    {"role": "assistant", "content": ""},
                ]
                tokens = prepare_conversation(tokenizer, conversation, attack_conversation)[0]
                # Everything except the (empty) target segment is used as the generation prompt.
                token_list.append(torch.cat(tokens[:5]))
                attack_conversations.append(attack_conversation)
            batch_completions = generate_ragged_batched(
                model,
                tokenizer,
                token_list=token_list,
                max_new_tokens=self.config.generation_config.max_new_tokens,
                temperature=self.config.generation_config.temperature,
                top_p=self.config.generation_config.top_p,
                top_k=self.config.generation_config.top_k,
                num_return_sequences=self.config.generation_config.num_return_sequences,
            )  # (N_steps, N_return_sequences, T)
            steps = []
            t1 = time.time()
            for i in range(len(optim_strings)):
                step = AttackStepResult(
                    step=i,
                    model_completions=batch_completions[i],
                    time_taken=times[i],
                    loss=losses[i],
                    model_input=attack_conversations[i],
                    model_input_tokens=token_list[i].tolist(),
                )
                steps.append(step)

            run = SingleAttackRunResult(
                original_prompt=conversation,
                steps=steps,
                total_time=t1 - t0,
            )
            runs.append(run)
        return AttackResult(runs=runs)

    def init_buffer(self, model, init_buffer_ids):
        """Create the attack buffer and seed it with the scored initial candidates.

        Args:
            model: model used to score the initial candidates.
            init_buffer_ids: Tensor of shape (N, n_optim_ids); ``run`` passes N == 1.

        Returns:
            AttackBuffer holding up to ``min(max(1, buffer_size), N)`` initial candidates.
        """
        config = self.config

        # Create the attack buffer and initialize the buffer ids
        buffer = AttackBuffer(config.buffer_size)

        # Compute the loss on the initial buffer entries
        compute_loss_fn = partial(self.compute_candidates_loss, model)
        init_buffer_losses, _ = with_max_batchsize(compute_loss_fn, init_buffer_ids)

        # Only as many candidates as were provided can be added; indexing up to
        # buffer_size would raise IndexError when buffer_size > N.
        n_init = min(max(1, config.buffer_size), init_buffer_ids.size(0))
        for i in range(n_init):
            # Store plain floats so losses compare/print like the per-step losses.
            buffer.add(float(init_buffer_losses[i]), init_buffer_ids[[i]])
        return buffer

    @torch.no_grad()
    def compute_candidates_loss(
        self,
        model: transformers.PreTrainedModel,
        attack_ids: Tensor,
    ) -> tuple[Tensor, Tensor]:
        """Computes the GCG loss on all candidate token id sequences.

        Args:
            model : transformers.PreTrainedModel
                the model to compute the loss with respect to
            attack_ids : Tensor, shape = (B, n_optim_ids)
                the token ids of the candidate attack sequences to evaluate

        Returns:
            loss : Tensor, shape = (B,)
                the GCG loss of each candidate on the current (possibly truncated) target
            acc : Tensor, shape = (B,), bool
                whether the greedy prediction matches every current target token

        Side effects:
            sets ``self.stop_flag`` when ``early_stop`` is enabled and some candidate
            matches the full target.
        """
        B = attack_ids.shape[0]
        T = self.pre_prompt_embeds.size(1)
        target_ids = self.target_ids[:, :self.target_length]
        target_embeds = self.target_embeds[:, :self.target_length]
        if self.prefix_cache:
            input_embeds = torch.cat(
                [
                    model.get_input_embeddings()(attack_ids),
                    self.post_embeds.repeat(B, 1, 1),
                    target_embeds.repeat(B, 1, 1),
                ],
                dim=1,
            )
            # Broadcast the single cached prefix over the candidate batch (expand() is a view).
            for i, kc in enumerate(self.prefix_cache.key_cache):
                self.prefix_cache.key_cache[i] = kc[:1, :, :T].expand(B, -1, -1, -1)
            for i, vc in enumerate(self.prefix_cache.value_cache):
                self.prefix_cache.value_cache[i] = vc[:1, :, :T].expand(B, -1, -1, -1)
            outputs = model(
                inputs_embeds=input_embeds,
                past_key_values=self.prefix_cache,
                use_cache=True,
            )
            # Restore the cache to a single sequence of length T for the next call.
            for i, kc in enumerate(self.prefix_cache.key_cache):
                self.prefix_cache.key_cache[i] = kc[:1]
            for i, vc in enumerate(self.prefix_cache.value_cache):
                self.prefix_cache.value_cache[i] = vc[:1]
            self.prefix_cache.crop(T)
        else:
            input_embeds = torch.cat(
                [
                    self.pre_prompt_embeds.repeat(B, 1, 1),
                    model.get_input_embeddings()(attack_ids),
                    self.post_embeds.repeat(B, 1, 1),
                    target_embeds.repeat(B, 1, 1),
                ],
                dim=1,
            )
            outputs = model(inputs_embeds=input_embeds)

        logits = outputs.logits
        # Logits at position n-1 predict token n, so shift by one before the target span.
        shift = logits.size(1) - target_ids.size(1)
        shift_logits = logits[..., shift - 1 : -1, :].contiguous()
        shift_labels = target_ids.repeat(B, 1)

        # Use the same loss type as the gradient computation (defaults to cross-entropy).
        loss_type = getattr(self.config, "loss", "ce")
        if loss_type == "ce":
            loss = compute_loss(shift_logits, shift_labels)
        else:
            loss = compute_loss(shift_logits, shift_labels, loss_type, self.config.mellowmax_alpha)

        acc = (shift_logits.argmax(-1) == shift_labels).all(-1)  # (B, T) -> (B,)

        # With grow_target, a match on a truncated target is not a perfect match.
        if self.config.early_stop and self.target_length >= self.target_ids.size(1):
            if acc.any().item():
                self.stop_flag = True

        del outputs
        gc.collect()
        torch.cuda.empty_cache()

        return loss, acc


class AttackBuffer:
    """Bounded collection of attack candidates ordered by ascending loss.

    With ``size == 0`` the buffer degenerates to a single slot that always holds the
    most recently added candidate.
    """

    def __init__(self, size: int):
        # List of (loss, optim_ids) pairs; index 0 is always the lowest loss.
        self.buffer = []
        self.size = size

    def add(self, loss: float, optim_ids: Tensor) -> None:
        entry = (loss, optim_ids)
        if self.size == 0:
            # Single-slot mode: overwrite unconditionally, nothing to sort.
            self.buffer = [entry]
            return

        is_full = len(self.buffer) >= self.size
        if is_full:
            # Evict the current worst (last) entry.
            self.buffer[-1] = entry
        else:
            self.buffer.append(entry)
        self.buffer.sort(key=lambda pair: pair[0])

    def get_best_ids(self) -> Tensor:
        _, best_ids = self.buffer[0]
        return best_ids

    def get_lowest_loss(self) -> float:
        best_loss, _ = self.buffer[0]
        return best_loss

    def get_highest_loss(self) -> float:
        worst_loss, _ = self.buffer[-1]
        return worst_loss


class SubstitutionSelectionStrategy:
    """Samples candidate attack sequences from the token gradient, as in the original GCG.

    ``target_ids`` and ``target_embeds`` are plain attributes; the caller may overwrite them
    between steps (``GCGAttack.run`` does so with the possibly truncated target).
    """

    def __init__(self, config: GCGConfig, prefix_cache: "DynamicCache | None", pre_prompt_embeds: Tensor, post_embeds: Tensor, target_embeds: Tensor, target_ids: Tensor, not_allowed_ids: Tensor):
        self.config = config
        # KV cache of the tokens preceding the optimized ids, or None to re-run them each call.
        self.prefix_cache = prefix_cache
        self.pre_prompt_embeds = pre_prompt_embeds
        self.post_embeds = post_embeds
        self.target_embeds = target_embeds
        self.target_ids = target_ids
        self.not_allowed_ids = not_allowed_ids
        # Moving average of the token gradient; only used when grad_momentum > 0.
        self.grad_buffer = None

    def __call__(
        self,
        ids: Tensor,
        model: transformers.PreTrainedModel,
        search_width: int,
        topk: int,
        n_replace: int,
        not_allowed_ids: Tensor,
        *args,
        **kwargs,
    ):
        """Sample candidates around ``ids``; see ``_sample_ids_from_grad``.

        ``not_allowed_ids`` is accepted for interface compatibility but not used here: the
        mask passed to the constructor (``self.not_allowed_ids``) is applied instead.
        """
        return self._sample_ids_from_grad(
            ids,
            model,
            search_width,
            topk,
            n_replace,
            *args,
            **kwargs,
        )

    def _sample_ids_from_grad(
        self,
        ids: Tensor,
        model: transformers.PreTrainedModel,
        search_width: int,
        topk: int = 256,
        n_replace: int = 1,
    ):
        """Returns `search_width` combinations of token ids based on the token gradient.
        Original GCG does this.

        Args:
            ids : Tensor, shape = (n_optim_ids)
                the sequence of token ids that are being optimized
            model : transformers.PreTrainedModel
                the model to compute the gradient with respect to
            search_width : int
                the number of candidate sequences to return
            topk : int
                the topk to be used when sampling from the gradient
            n_replace: int
                the number of token positions to update per sequence

        Returns:
            sampled_ids : Tensor, shape = (search_width, n_optim_ids)
                sampled token ids
            sampled_ids_pos : Tensor, shape = (search_width, n_replace)
                the positions that were replaced in each candidate
            grad : Tensor, shape = (n_optim_ids, vocab_size)
                the (smoothed / momentum-averaged) gradient with disallowed ids set to +inf
        """
        # Initial gradient computation
        grad = self.compute_token_gradient(ids.unsqueeze(0), model).squeeze(0)  # (n_optim_ids, vocab_size)

        # getattr defaults keep older configs without these fields working.
        n_smoothing = getattr(self.config, "grad_smoothing", 1)
        if n_smoothing > 1:
            grad = self._smooth_gradient(ids, model, grad, n_smoothing)

        grad_momentum = getattr(self.config, "grad_momentum", 0.0)
        if grad_momentum > 0.0:
            if self.grad_buffer is None:
                self.grad_buffer = grad
            else:
                self.grad_buffer = grad_momentum * self.grad_buffer + (1 - grad_momentum) * grad
            grad = self.grad_buffer

        n_optim_tokens = len(ids)
        original_ids = ids.repeat(search_width, 1)

        if self.not_allowed_ids is not None:
            # Copy first so the +inf mask is not written into self.grad_buffer.
            grad = grad.clone()
            grad[:, self.not_allowed_ids.to(grad.device)] = float("inf")
        # (n_optim_ids, topk)
        topk_ids = grad.topk(topk, dim=1, largest=False, sorted=False).indices

        sampled_ids_pos = torch.randint(
            0, n_optim_tokens, (search_width, n_replace), device=grad.device
        )  # (search_width, n_replace)
        sampled_topk_idx = torch.randint(
            0, topk, (search_width, n_replace, 1), device=grad.device
        )

        sampled_ids_val = (
            topk_ids[sampled_ids_pos].gather(2, sampled_topk_idx).squeeze(2)
        )  # (search_width, n_replace)

        new_ids = original_ids.scatter_(
            1, sampled_ids_pos, sampled_ids_val
        )  # (search_width, n_optim_ids)

        return new_ids, sampled_ids_pos, grad

    def _smooth_gradient(
        self,
        ids: Tensor,
        model: transformers.PreTrainedModel,
        grad: Tensor,
        n_smoothing: int,
        batch_size: int = 64,
    ) -> Tensor:
        """Average ``grad`` with gradients at ``n_smoothing - 1`` random single-token perturbations of ``ids``.

        Args:
            ids : Tensor, shape = (n_optim_ids)
            model : the model to differentiate
            grad : Tensor, shape = (n_optim_ids, vocab_size), gradient at ``ids`` itself
            n_smoothing : total number of gradients averaged (including ``grad``)
            batch_size : perturbations evaluated per forward/backward pass

        Returns:
            Tensor, shape = (n_optim_ids, vocab_size)
        """
        # Replacement tokens are drawn from the vocabulary (the gradient's last dim),
        # excluding disallowed ids. A boolean mask avoids an O(V * |not_allowed|) scan.
        vocab_size = grad.size(-1)
        allowed_mask = torch.ones(vocab_size, dtype=torch.bool, device=grad.device)
        if self.not_allowed_ids is not None:
            allowed_mask[self.not_allowed_ids.to(grad.device)] = False
        allowed_ids = allowed_mask.nonzero().squeeze(1).tolist()

        total_samples = n_smoothing - 1
        all_grads = grad.clone()

        # Process in batches
        for batch_start in range(0, total_samples, batch_size):
            current_batch_size = min(batch_size, total_samples - batch_start)

            grad_ids_batch = ids.clone().unsqueeze(0).repeat(current_batch_size, 1)  # (batch_size, n_optim_ids)

            random_positions = torch.randint(0, grad_ids_batch.shape[1], (current_batch_size, 1), device=ids.device)
            random_indices = torch.tensor([random.choice(allowed_ids) for _ in range(current_batch_size)],
                                          device=ids.device).unsqueeze(1)
            grad_ids_batch.scatter_(1, random_positions, random_indices)
            # compute_token_gradient sums per-sample losses, so each row is an unscaled
            # per-sample gradient and summing them weights every perturbation equally.
            batch_grads = self.compute_token_gradient(grad_ids_batch, model).detach()
            all_grads += batch_grads.sum(0)
        return all_grads / n_smoothing

    def compute_token_gradient(
        self,
        optim_ids: Tensor,
        model: transformers.PreTrainedModel,
    ) -> Tensor:
        """Computes the gradient of the GCG loss w.r.t the one-hot token matrix.

        Args:
            optim_ids : Tensor, shape = (N, n_optim_ids)
                the sequences of token ids that are being optimized
            model : transformers.PreTrainedModel
                the model to compute the gradient with respect to

        Returns:
            grad : Tensor, shape = (N, n_optim_ids, vocab_size)
                per-sequence gradient of the GCG loss with respect to the one-hot token
                embeddings (losses are summed over N, so rows are not scaled by 1/N)
        """
        assert optim_ids.ndim == 2
        embedding_layer = model.get_input_embeddings()

        # Create the one-hot encoding matrix of our optimized token ids
        optim_ids_onehot = torch.nn.functional.one_hot(
            optim_ids, num_classes=embedding_layer.num_embeddings
        )
        optim_ids_onehot = optim_ids_onehot.to(dtype=model.dtype, device=model.device)
        optim_ids_onehot.requires_grad_()

        # (N, num_optim_tokens, vocab_size) @ (vocab_size, embed_dim) -> (N, num_optim_tokens, embed_dim)
        if self.config.use_constrained_gradient:
            optim_embeds = (
                optim_ids_onehot / optim_ids_onehot.sum(dim=-1, keepdim=True)
            ) @ embedding_layer.weight
        else:
            optim_embeds = optim_ids_onehot @ embedding_layer.weight

        B = optim_embeds.shape[0]
        if self.prefix_cache:
            T = self.pre_prompt_embeds.shape[1]
            input_embeds = torch.cat(
                [optim_embeds, self.post_embeds.repeat(B, 1, 1), self.target_embeds.repeat(B, 1, 1)], dim=1
            )
            # Broadcast the single cached prefix over the batch (expand() is a view).
            for i, kc in enumerate(self.prefix_cache.key_cache):
                self.prefix_cache.key_cache[i] = kc[:1, :, :T].expand(B, -1, -1, -1)
            for i, vc in enumerate(self.prefix_cache.value_cache):
                self.prefix_cache.value_cache[i] = vc[:1, :, :T].expand(B, -1, -1, -1)
            output = model(
                inputs_embeds=input_embeds,
                past_key_values=self.prefix_cache,
                use_cache=True,
            )
            # Restore the cache to a single sequence of length T for the next call.
            for i, kc in enumerate(self.prefix_cache.key_cache):
                self.prefix_cache.key_cache[i] = kc[:1]
            for i, vc in enumerate(self.prefix_cache.value_cache):
                self.prefix_cache.value_cache[i] = vc[:1]
            self.prefix_cache.crop(T)
        else:
            input_embeds = torch.cat(
                [
                    self.pre_prompt_embeds.repeat(B, 1, 1),
                    optim_embeds,
                    self.post_embeds.repeat(B, 1, 1),
                    self.target_embeds.repeat(B, 1, 1),
                ],
                dim=1,
            )
            output = model(inputs_embeds=input_embeds)
        logits = output.logits

        # Shift logits so token n-1 predicts token n
        shift = input_embeds.shape[1] - self.target_ids.shape[1]
        shift_logits = logits[..., shift - 1 : -1, :].contiguous()  # (N, num_target_ids, vocab_size)
        shift_labels = self.target_ids.repeat(B, 1)

        loss_type = getattr(self.config, "loss", "ce")
        if loss_type == "ce":
            loss = compute_loss(shift_logits, shift_labels)
        else:
            loss = compute_loss(shift_logits, shift_labels, loss_type, self.config.mellowmax_alpha, self.not_allowed_ids)
        # Sum (not mean): for N == 1 this is identical, and for N > 1 each sequence's
        # gradient keeps its own scale instead of being divided by N.
        loss = loss.sum()

        optim_ids_onehot_grad = torch.autograd.grad(
            outputs=[loss],
            inputs=[optim_ids_onehot],
            create_graph=False,
            retain_graph=False
        )[0]
        return optim_ids_onehot_grad
