import copy
import logging
import random
import string
import sys
from functools import lru_cache
from typing import Callable, Literal

import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence
from tqdm import tqdm
from transformers import DynamicCache, PreTrainedModel, PreTrainedTokenizerBase

from src.attacks.attack import Conversation
from src.io_utils import free_vram


@torch.no_grad()
def generate_ragged_batched(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    token_list: list[torch.IntTensor] | None = None,
    embedding_list: list[torch.FloatTensor] | None = None,
    initial_batch_size: int = 64,
    use_cache: bool = True,
    verbose: bool = False,
    **kwargs,
) -> list[list[str]]:
    """
    Generate completions for a list of ragged inputs (tokens or embeddings) in batches.

    Inputs are sorted by length so that similarly sized prompts share a batch, run through
    `generate_ragged` via `with_max_batchsize` (which halves the batch size on CUDA OOM),
    and the results are put back into the original input order.

    Args:
        model: The language model to use for generation.
        tokenizer: The tokenizer for the model.
        token_list: One 1D tensor of token ids per prompt. Mutually exclusive with
            `embedding_list`.
        embedding_list: One embedding tensor per prompt, sequence along dim 0.
            Mutually exclusive with `token_list`.
        initial_batch_size: Starting batch size for generation.
        use_cache: Forwarded to `generate_ragged`.
        verbose: Whether to show a progress bar.
        **kwargs: Additional arguments will be passed to `generate_ragged`.
    Returns:
        One output per input, in the original input order (by default a list of completion
        strings per prompt). An empty input list yields an empty list.
    Raises:
        ValueError: If both or neither of `token_list` / `embedding_list` are provided.
    """
    if token_list is not None and embedding_list is not None:
        raise ValueError("Only one of token_list or embedding_list should be provided.")
    if token_list is None and embedding_list is None:
        raise ValueError("One of token_list or embedding_list must be provided.")
    input_type = "tokens" if token_list is not None else "embeddings"
    input_list = token_list if token_list is not None else embedding_list

    if len(input_list) == 0:
        # Nothing to generate; avoid running the batching machinery on an empty list.
        return []

    # Shorter sequences will come first to maximize batch size
    sorted_indexed_inputs = sorted(enumerate(input_list), key=lambda x: x[1].size(0))
    sorted_input_list = [item for _, item in sorted_indexed_inputs]
    original_indices = [index for index, _ in sorted_indexed_inputs]

    def func(chunk):
        """Run `generate_ragged` on a single chunk of (sorted) inputs."""
        return generate_ragged(
            model=model,
            tokenizer=tokenizer,
            token_list=chunk if input_type == "tokens" else None,
            embedding_list=chunk if input_type == "embeddings" else None,
            use_cache=use_cache,
            **kwargs,
        )

    sorted_outputs = with_max_batchsize(func, sorted_input_list, initial_batch_size=initial_batch_size, verbose=verbose)

    # Unsort the outputs to match the original input order
    outputs = [None] * len(input_list)
    for sorted_pos, original_index in enumerate(original_indices):
        outputs[original_index] = sorted_outputs[sorted_pos]
    return outputs


def with_max_batchsize(function: Callable, *inputs, initial_batch_size: int | None = None, verbose: bool = False):
    """
    Dynamically adjust batch size if an OOM error occurs.
    TODO: Try increasing batch size again if we have enough VRAM.

    Args:
        function: Callable
            A function that takes one or more arguments and returns a tensor or list of tensors.
            All input arguments should have the same length in their first dimension (batch dimension).
        *inputs:
            The inputs to pass to the function. All inputs must be tensors or lists and have the same length in their first dimension.
        initial_batch_size:
            Starting batch size for execution.
        verbose:
            Whether to print progress.
    Returns:
        The output of the function.
    """
    if not inputs:
        raise ValueError("At least one input must be provided")

    # Verify all inputs have the same length
    input_length = len(inputs[0])
    for i, inp in enumerate(inputs[1:], 1):
        if len(inp) != input_length:
            raise ValueError(f"All inputs must have the same length. Input 0 has length {input_length}, but input {i} has length {len(inp)}")

    outputs = []
    i = 0
    if initial_batch_size is None:
        initial_batch_size = input_length
    batch_size = min(initial_batch_size, input_length)
    pbar = tqdm(total=input_length, desc=f"Running function b={batch_size}", file=sys.stdout) if verbose else None

    while i < input_length:
        try:
            free_vram()
            # Create chunks for all inputs
            chunks = [inp[i:i + batch_size] for inp in inputs]
            output = function(*chunks)
            outputs.append(output)
            i += batch_size  # Move to the next batch
            if verbose:
                pbar.update(batch_size)
        except torch.cuda.OutOfMemoryError:
            # If we hit OOM, reduce batch size and retry the same chunk
            batch_size = batch_size // 2
            if verbose:
                pbar.set_description(f"Running function b={batch_size}")
            if batch_size < 1:
                raise RuntimeError(
                    "OOM even with batch_size=1; cannot generate further."
                )
    if verbose:
        pbar.close()

    if all(isinstance(x, torch.Tensor) for x in outputs):
        outputs = torch.cat(outputs, dim=0)
        assert len(outputs) == input_length
    elif all(isinstance(x, tuple) for x in outputs):
        # Transpose and concatenate tuple outputs
        # Handle both tensors and lists within tuples
        outputs_processed = []
        for i in range(len(outputs[0])):
            elements = [x[i] for x in outputs]
            if all(isinstance(e, torch.Tensor) for e in elements):
                outputs_processed.append(torch.cat(elements, dim=0))
            elif all(isinstance(e, list) for e in elements):
                outputs_processed.append([item for sublist in elements for item in sublist])
            else:
                types = ", ".join(f"{type(e).__name__}" for e in elements)
                raise TypeError(f"Wrapped functions may only return Tensors or lists, not {types}")
        outputs = tuple(outputs_processed)
        assert all(len(o) == input_length for o in outputs)
    elif all(isinstance(x, dict) for x in outputs):
        outputs = {k: [item for o in outputs for item in o[k]] for k in outputs[0].keys()}
        assert all(len(v) == input_length for v in outputs.values())
    else:
        outputs = [item for sublist in outputs for item in sublist]
        assert len(outputs) == input_length

    return outputs


@torch.no_grad
def generate_ragged(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    embedding_list: list[torch.FloatTensor] | None = None,
    token_list: list[torch.IntTensor] | None = None,
    max_new_tokens: int = 256,
    return_tokens: bool = False,
    padding_side: Literal["left", "right"] = "right",
    use_cache: bool = True,
    temperature: float = 0.0,
    top_p: float = 1.0,
    top_k: int = 0,
    num_return_sequences: int = 1,
) -> list[list[str]] | torch.Tensor:
    """
    Generate completions for multiple prompts in a single batch.
    No KV-cache for left-padding yet.
    Heavily tested across models to be close to individual generations.
    This is far from trivial due to various padding (left/right) and masking issues.
    The final function is still not identical to individual generations, but it is close.
    The reason for this is probably that attention masks typically don't use -inf for
    masked tokens, but instead have values like -65504 for float16.
    This can lead to small differences in the final logits and thus the generated tokens.
    We are much closer to individual generations than HF model.generate, which often
    fails in mysterious ways for LLama & Qwen models. Generations for CircuitBreakers
    often look weird, but are actually just what the model would output with `generate`
    as well.

    Number of generations that are the same as single-batch:
        Model Name                         This function    HF generate
        cais/zephyr_7b_r2d2                      100/100        100/100
        ContinuousAT/Llama-2-7B-CAT              100/100         39/100
        ContinuousAT/Phi-CAT                      90/100         95/100
        ContinuousAT/Zephyr-CAT                  100/100        100/100
        google/gemma-2-2b-it                      55/100         56/100
        meta-llama/Meta-Llama-3.1-8B-Instruct     62/100         11/100
        meta-llama/Llama-2-7b-chat-hf             88/100         30/100
        microsoft/Phi-3-mini-4k-instruct          53/100         50/100
        mistralai/Mistral-7B-Instruct-v0.3        83/100         79/100
        qwen/Qwen2-7B-Instruct                    78/100         19/100
        ---------------------------------------------------------------
        Total                                   809/1000       579/1000

    Args:
        model: A pretrained model.
        tokenizer: A pretrained tokenizer. If it defines no pad token, its EOS token id is
            used to pre-fill the output token buffer instead.
        embedding_list: list[torch.Tensor], optional
            A list of embeddings for each prompt. Should not be padded and can be of different lengths.
        token_list: list[torch.Tensor], optional
            A list of tokens for each prompt. Should not be padded and can be of different lengths.
        max_new_tokens: The maximum number of tokens to generate for each prompt (non-negative).
            Generation stops early once every prompt has produced an EOS token.
        return_tokens: If True, return the generated token ids instead of decoded strings.
        padding_side: How prompts are aligned in the batch. "left" is only supported
            with use_cache=False.
        use_cache: Whether to use a KV-cache (right padding only). Forced off for Gemma 2.
        temperature: 0.0 means greedy (argmax) decoding; > 0 samples from
            softmax(logits / temperature), optionally after top-p / top-k filtering.
        top_p: Nucleus filtering threshold; only applied when sampling and top_p < 1.0.
        top_k: Top-k filtering; only applied when sampling and top_k > 0.
        num_return_sequences: Number of completions generated per prompt.
    Returns:
        If return_tokens is False: one list per prompt with num_return_sequences decoded
        completions, each truncated at the first EOS token.
        If return_tokens is True: a (B, num_return_sequences, max_new_tokens) tensor of token
        ids; columns after an early exit keep the fill (pad or EOS) token id.
    Raises:
        ValueError: If not exactly one of embedding_list / token_list is given, or if
            padding_side is unknown.
        NotImplementedError: If padding_side="left" is combined with use_cache=True.
    """
    if (embedding_list is None) == (token_list is None):
        raise ValueError("One of embedding_list or token_list must be provided.")
    if embedding_list is not None:
        assert all(e.ndim == 2 for e in embedding_list), "Embeddings must be 2D."
        embedding_list = [e.to(model.device) for e in embedding_list]
    if token_list is not None:
        assert all(t.ndim == 1 for t in token_list), "Tokens must be 1D."
        token_list = [t.to(model.device) for t in token_list]
        embedding_list = [
            model.get_input_embeddings()(t.unsqueeze(0))[0] for t in token_list
        ]
    assert embedding_list is not None
    # TODO: Implement KV-caching for Gemma
    # Lower-cased for consistency with the other model-name checks in this file.
    if use_cache and "gemma-2" in model.name_or_path.lower():
        logging.warning("KV-cache not implemented for Gemma 2. Disabling cache.")
        use_cache = False

    B = len(embedding_list)

    # torch.full cannot be filled with None, so tokenizers without a pad token would crash.
    # EOS is a safe substitute: decoded completions are cut at the first EOS anyway.
    fill_token_id = tokenizer.pad_token_id
    if fill_token_id is None:
        fill_token_id = tokenizer.eos_token_id

    def sample_next_token(logits: torch.Tensor) -> torch.Tensor:
        if temperature > 0.0:
            logits = logits / temperature
            if top_p < 1.0:
                logits = top_p_filtering(logits, top_p)
            if top_k > 0:
                logits = top_k_filtering(logits, top_k)
            next_tokens = torch.multinomial(logits.softmax(dim=-1), num_samples=1)[:, 0]
        else:
            next_tokens = logits.argmax(dim=-1)
        return next_tokens

    if num_return_sequences == 0:
        if return_tokens:
            # Integer dtype to match the token-id tensor returned in the normal case.
            return torch.zeros((B, 0, max_new_tokens), dtype=torch.long)
        else:
            return [[] for _ in range(B)]

    all_tokens = []
    for i in range(num_return_sequences):
        tokens = torch.full((B, max_new_tokens), fill_token_id)
        generation_completed = torch.zeros(B, dtype=torch.bool)
        if padding_side == "left":
            if use_cache:
                raise NotImplementedError("KV-cache not implemented for left padding.")
            # Add left padding
            embeddings = pad_sequence(
                [e.flip(0) for e in embedding_list], batch_first=True, padding_value=0
            ).flip(1)
            padded_embeddings = F.pad(embeddings, (0, 0, 0, max_new_tokens))
            # Create attention mask and position ids
            lengths = [
                {
                    "padding": embeddings.size(1) - e.size(0),
                    # Prompt tokens plus every position that may be generated, so that
                    # padding + generation == padded_embeddings.size(1) for every row.
                    "generation": e.size(0) + max_new_tokens,
                }
                for e in embedding_list
            ]
            attention_mask = torch.stack(
                [
                    torch.cat([torch.zeros(pl["padding"]), torch.ones([pl["generation"]])])
                    for pl in lengths
                ]
            ).to(model.device)
            position_ids = torch.stack(
                [
                    torch.cat([torch.zeros(pl["padding"]), torch.arange(pl["generation"])])
                    for pl in lengths
                ]
            ).long().to(model.device)
            next_token_idx = embeddings.size(1)
            for i in range(max_new_tokens):
                outputs = model(
                    inputs_embeds=padded_embeddings[:, :next_token_idx],
                    attention_mask=attention_mask[:, :next_token_idx],
                    position_ids=position_ids[:, :next_token_idx],
                )
                logits = outputs.logits[torch.arange(B), -1]
                next_tokens = sample_next_token(logits)
                padded_embeddings[torch.arange(B), next_token_idx] = (
                    model.get_input_embeddings()(next_tokens).detach()
                )
                tokens[:, i] = next_tokens.cpu()
                generation_completed |= next_tokens.cpu() == tokenizer.eos_token_id
                if generation_completed.all():
                    logging.info(f"Early exit after {i}/{max_new_tokens} tokens.")
                    break
                next_token_idx += 1
        elif padding_side == "right":
            # Add right padding
            embeddings = pad_sequence(
                [e for e in embedding_list], batch_first=True, padding_value=0
            )
            padded_embeddings = F.pad(embeddings, (0, 0, 0, max_new_tokens))
            next_token_idx = torch.tensor([e.size(0) for e in embedding_list])

            if use_cache:
                # This is the hot path so we have additional optimizations here.
                # As we generate tokens, we keep track of which prompts are completed,
                # and only generate tokens for the active prompts.
                # This is slightly (~20%) slower if all prompts have similar length,
                # but faster if prompts have different lengths and **saves VRAM**.

                # Fill prefix cache
                past_key_values = DynamicCache()
                if next_token_idx.min() > 1:
                    model(
                        inputs_embeds=padded_embeddings[:, : next_token_idx.min() - 1],
                        past_key_values=past_key_values,
                        use_cache=True,
                    )
                for i in range(max_new_tokens):
                    # Caching with right padding is a bit tricky:
                    # We have to feed more than one token at each forward pass :(.
                    # Instead, we feed a 'window' from the last token of the shortest prompt
                    # to the last token of the longest prompt.
                    # This means that caching works best when sequences have similar length.
                    active_mask = ~generation_completed
                    active_mask_idx = torch.arange(B)[active_mask]
                    active_embeddings = padded_embeddings[active_mask, next_token_idx.min() - 1 : next_token_idx.max()]
                    logits = model(
                        inputs_embeds=active_embeddings,
                        past_key_values=past_key_values,
                        use_cache=True,
                    ).logits

                    next_tokens = torch.full((B,), tokenizer.eos_token_id, device=model.device)
                    next_token_idx_active = next_token_idx[active_mask]
                    logits = logits[
                        torch.arange(logits.size(0)),
                        next_token_idx_active - next_token_idx.min()
                    ]
                    next_tokens[active_mask] = sample_next_token(logits)
                    padded_embeddings[active_mask_idx, next_token_idx_active] = model.get_input_embeddings()(next_tokens[active_mask])
                    tokens[:, i] = next_tokens.cpu()
                    continue_mask = (next_tokens.cpu() != tokenizer.eos_token_id)[active_mask]
                    # have to manually crop the past_key_values to the correct length
                    # since we only add a single step at a time
                    for j in range(len(past_key_values.key_cache)):
                        past_key_values.key_cache[j] = past_key_values.key_cache[j][continue_mask, :, :next_token_idx.min()]
                        past_key_values.value_cache[j] = past_key_values.value_cache[j][continue_mask, :, :next_token_idx.min()]

                    generation_completed |= next_tokens.cpu() == tokenizer.eos_token_id
                    if generation_completed.all():
                        logging.info(f"Early exit after {i}/{max_new_tokens} tokens.")
                        break
                    next_token_idx += 1
            else:
                for i in range(max_new_tokens):
                    outputs = model(inputs_embeds=padded_embeddings[:, : next_token_idx.max()])
                    logits = outputs.logits[torch.arange(B), next_token_idx - 1]
                    next_tokens = sample_next_token(logits)
                    padded_embeddings[torch.arange(B), next_token_idx] = (
                        model.get_input_embeddings()(next_tokens).detach()
                    )
                    tokens[:, i] = next_tokens.cpu()
                    generation_completed |= next_tokens.cpu() == tokenizer.eos_token_id
                    if generation_completed.all():
                        logging.info(f"Early exit after {i}/{max_new_tokens} tokens.")
                        break
                    next_token_idx += 1
        else:
            raise ValueError(f"Unknown padding_side: {padding_side}")
        all_tokens.append(tokens)

    all_tokens = torch.stack(all_tokens, dim=1)  # (B, N, T)
    if return_tokens:
        return all_tokens
    completion = [tokenizer.batch_decode(all_tokens[i], skip_special_tokens=False) for i in range(B)]
    completion = [[c.split(tokenizer.eos_token)[0] for c in completion[i]] for i in range(B)]
    return completion


@torch.no_grad
def get_losses_batched(
    model: PreTrainedModel,
    targets: list[torch.Tensor],
    embedding_list: list[torch.Tensor] | None = None,
    token_list: list[torch.Tensor] | None = None,
    padding_side: Literal["left", "right"] = "right",
    initial_batch_size: int | None = None,
) -> list[torch.Tensor]:
    """
    Get per-timestep cross-entropy losses for multiple ragged prompts, in batches.
    No KV-cache for now.

    Inputs and targets are padded on the same side, the model's logits are flattened and
    scored element-wise against the flattened padded targets. The padded target length
    must therefore equal the padded input length.

    Args:
        model: A pretrained model.
        targets: A list of 1D tensors containing the target tokens for each prompt. They
            are compared position-by-position with the logits, so they presumably need to
            be aligned with the inputs by the caller — TODO confirm the expected shift.
        embedding_list: list[torch.Tensor], optional
            A list of embeddings for each prompt. Should not be padded and can be of different lengths.
        token_list: list[torch.Tensor], optional
            A list of tokens for each prompt. Should not be padded and can be of different lengths.
        padding_side: "right" (recommended) or "left". Left padding may yield NaNs.
        initial_batch_size: Starting batch size for `with_max_batchsize`, which halves it
            on CUDA OOM. Defaults to processing all prompts at once.
    Returns:
        One 1D tensor of per-position losses per prompt, of length ``len(target) - 1``.
        An empty input list yields an empty list.
    Raises:
        ValueError: If not exactly one of embedding_list / token_list is given, or if
            padding_side is unknown.
    """
    if (embedding_list is None) == (token_list is None):
        raise ValueError("Either embedding_list or token_list must be provided.")
    if embedding_list is not None:
        assert all(e.ndim == 2 for e in embedding_list), "Embeddings must be 2D."
        embedding_list = [e.to(model.device) for e in embedding_list]
    if token_list is not None:
        assert all(t.ndim == 1 for t in token_list), "Tokens must be 1D."
        token_list = [t.to(model.device) for t in token_list]
        embedding_list = [
            model.get_input_embeddings()(t.unsqueeze(0))[0] for t in token_list
        ]
    assert embedding_list is not None
    assert all(t.ndim == 1 for t in targets), "Targets must be 1D."

    if len(embedding_list) == 0:
        # Nothing to score; the batching helper cannot merge zero outputs.
        return []

    if padding_side == "left":
        # Warn once per call rather than once per chunk.
        logging.warning("Padding side 'left' is not recommended for get_losses_batched as it may yield nans.")

    def get_losses_func(embedding_list, targets):
        """Compute per-position losses for one chunk of prompts and targets."""
        B = len(embedding_list)
        if padding_side == "left":
            # Add left padding
            embeddings = pad_sequence(
                [e.flip(0) for e in embedding_list], batch_first=True, padding_value=0
            ).flip(1)

            targets_padded = pad_sequence(
                [t.flip(0) for t in targets], batch_first=True, padding_value=0
            ).flip(1)
            # Create attention mask and position ids
            lengths = [
                {
                    "padding": embeddings.size(1) - e.size(0),
                    "generation": e.size(0),
                }
                for e in embedding_list
            ]
            attention_mask = torch.stack(
                [
                    torch.cat([torch.zeros(pl["padding"]), torch.ones([pl["generation"]])])
                    for pl in lengths
                ]
            ).to(model.device)
            position_ids = torch.stack(
                [
                    torch.cat(
                        [torch.zeros(pl["padding"]), torch.arange(pl["generation"])]
                    )
                    for pl in lengths
                ]
            ).long().to(model.device)

            outputs = model(
                inputs_embeds=embeddings,
                attention_mask=attention_mask,
                position_ids=position_ids,
            ).logits
            losses = F.cross_entropy(
                outputs.reshape(-1, outputs.size(-1)),
                targets_padded.view(-1).to(outputs.device),
                reduction="none",
            )
            losses = losses.view(B, -1)
            # Left padding: the real positions are the last len(target) ones.
            losses = [losses[i, -t.size(0) : -1] for i, t in enumerate(targets)]
        elif padding_side == "right":
            # Add right padding
            embeddings = pad_sequence(
                [e for e in embedding_list], batch_first=True, padding_value=0
            )
            targets_padded = pad_sequence(
                [t for t in targets], batch_first=True, padding_value=0
            )
            outputs = model(inputs_embeds=embeddings).logits
            losses = F.cross_entropy(
                outputs.reshape(-1, outputs.size(-1)),
                targets_padded.view(-1).to(outputs.device),
                reduction="none",
            )
            losses = losses.view(B, -1)
            losses = [losses[i, : t.size(0) - 1] for i, t in enumerate(targets)]
        else:
            raise ValueError(f"Unknown padding_side: {padding_side}")
        return losses

    if initial_batch_size is None:
        initial_batch_size = len(embedding_list)
    losses = with_max_batchsize(get_losses_func, embedding_list, targets, initial_batch_size=initial_batch_size)
    return losses


def prepare_tokens(
    tokenizer: PreTrainedTokenizerBase,
    prompt: str,
    target: str = "",
    attack: str | None = None,
    placement: Literal["prompt", "suffix"] = "suffix",
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """For many attacks, we need to figure out how exactly to tokenize the input.
    Since only some models add a space or various control tokens, we have to figure
    out the exact format. We want to make sure that the first generated token is
    exactly 'Sure', and not a space or control token.

    We thus chunk the sequence into the following 5 parts (some of which may be empty):

    [PRE] + [Prompt] + [Attack] + [POST] + [Target]

    Treating prompt and attack separately is important for optimization, as we only
    want to optimize the attack part.

    Parameters:
    - tokenizer: The tokenizer to use.
    - prompt: The prompt string to use.
    - target: The target string to use.
    - attack: The attack string to use. Required unless placement is "prompt".
    - placement: Where to place the attack. Can be either "prompt" or "suffix".
      With "prompt", the whole prompt is treated as the attack.

    Returns:
    - pre_tokens: The tokens before the prompt.
    - prompt_tokens: The tokens of the prompt.
    - attack_tokens: The tokens of the attack. <- optimize these
    - post_tokens: The tokens after the attack.
    - target_tokens: The tokens of the target string. <- apply loss here

    Raises:
    - ValueError: if attack is missing for suffix placement, or if the POST tokens
      cannot be located in the tokenized chat.
    """
    if placement == "prompt":
        attack, prompt = prompt, ""
    elif attack is None:
        raise ValueError("If placement is 'suffix', attack must be provided.")

    # The actual chat by the user. Its tokenization does not depend on how many
    # messages are used to probe the template below, so tokenize it only once.
    chat = [
        {"role": "user", "content": prompt + attack},
        {"role": "assistant", "content": target},
    ]
    tokenized_together = tokenize_chats([chat], tokenizer)[0]

    # Some tokenizers and templates (e.g., allenai/Llama-3.1-Tulu-3-8B-DPO) need more
    # messages because their tokenization is more likely to have weird splits.
    for num_messages in [100, 1000, 10000]:
        pre_tokens, post_tokens, suffix_tokens = get_pre_post_suffix_tokens(tokenizer, num_messages)
        # We now cut the tokenized sequence into parts step-by-step.
        # First, we remove the prefix and suffix tokens, as we already know the prefix and
        # don't need the suffix. An explicit end index is used because `x[a:-0]` would be
        # empty when there are no suffix tokens.
        end = len(tokenized_together) - len(suffix_tokens)
        prompt_attack_post_target = tokenized_together[len(pre_tokens) : end]
        # We now look for the post tokens. These are between [prompt + attack] and [target].
        n_windows = len(prompt_attack_post_target) - len(post_tokens) + 1
        if n_windows <= 0:
            # Too short to contain the post tokens (torch.stack([]) would fail); try a
            # different probe size.
            continue

        # Now, we cut out sliding views from the remaining tokens and check if they match the post tokens.
        sliding_windows = torch.stack([
            prompt_attack_post_target[start:start + len(post_tokens)]
            for start in range(n_windows)
        ])

        # Compare each window with post_tokens
        matches = torch.all(sliding_windows == post_tokens, dim=1)
        # Find the first match index
        match_indices = torch.where(matches)[0]

        if len(match_indices) > 0:
            # Get the first match position
            match_start = match_indices[0].item()
            prompt_attack_tokens = prompt_attack_post_target[:match_start]
            target_tokens = prompt_attack_post_target[match_start + len(post_tokens):]
            break
    else:
        raise ValueError(
            f"Unable to find consistent tokenizer patterns for {tokenizer.name_or_path}"
        )

    tokenized_together_no_attack = get_tokenized_no_attack(prompt, target, tokenizer)

    # The attack occupies the tokens by which the full chat is longer than the chat
    # without the attack; they are the tail of prompt_attack_tokens.
    attack_length = len(tokenized_together) - len(tokenized_together_no_attack)

    prompt_tokens, attack_tokens = torch.tensor_split(
        prompt_attack_tokens, [prompt_attack_tokens.size(0)-attack_length]
    )
    if "llama-2" in tokenizer.name_or_path.lower():
        # LLama 2 models have incorrect templating and need to be fixed manually
        post_tokens = torch.cat([post_tokens, torch.tensor([29871])])

    return pre_tokens, prompt_tokens, attack_tokens, post_tokens, target_tokens


# Module-level memo used by `prepare_conversation`: maps a tokenizer object to the
# (post_len, suffix_len) pair measured for its assistant-message template.
TOKENIZER_CACHE: dict[object, tuple[int, int]] = dict()


class TokenMergeError(Exception):
    """Signals that tokenizer merges span the prompt/attack boundary, so they cannot be separated."""


def prepare_conversation(
    tokenizer: PreTrainedTokenizerBase,
    conversation: Conversation,
    conversation_opt: Conversation | None = None,
) -> list[list[torch.Tensor]]:
    """For many attacks, we need to figure out how exactly to tokenize the input.
    Since only some models want a space or various control tokens, we have to figure
    out the exact format. We want to make sure that the first generated token is
    exactly 'Sure', and not a space or control token.

    We thus chunk each back-and-forth message pair into the following parts
    (some of which may be empty):

    [PRE] + [Attack_Prefix0] + [Prompt0] + [Attack_Suffix0] + [POST0] + [Target0] +
    [SEP] + [Attack_Prefix1] + [Prompt1] + [Attack_Suffix1] + [POST1] + [Target1] +
    [SEP] + [Attack_Prefix2] + [Prompt2] + [Attack_Suffix2] + [POST2] + [Target2] ...

    Treating prompt and attack separately is important for optimization, as we only
    want to optimize the attack part.

    Parameters:
    - tokenizer: The tokenizer to use.
    - conversation: The conversation to use. Last message must be assistant message.
    - conversation_opt: The conversation to use for the attack. Parts of the string in
                        this conversation that are not in conversation are used as attack.

    Returns:
    A list with one entry per user/assistant turn. Each entry is a list of 6 tensors:
    - pre_tokens: The tokens before the prompt ([PRE] for the first turn, [SEP] afterwards).
    - attack_prefix_tokens: The tokens of the attack prefix. <- optimize these
    - prompt_tokens: The tokens of the prompt.
    - attack_suffix_tokens: The tokens of the attack suffix. <- optimize these
    - post_tokens: The tokens after the attack.
    - target_tokens: The tokens of the target string. <- apply loss here

    Raises:
    - TokenMergeError: if the clean prompt tokens do not appear verbatim inside the
      attacked prompt tokens (merges across the prompt/attack boundary).
    """
    assert conversation[-1]["role"] == "assistant", "Last message must be assistant message."
    if conversation_opt is None:
        conversation_opt = copy.deepcopy(conversation)

    def get_common_prefix_len(tokens: list[torch.Tensor]) -> int:
        """Number of leading positions on which all token sequences agree."""
        max_length = max(t.size(0) for t in tokens)
        tokens = [F.pad(t, (0, max_length - t.size(0)), value=-1) for t in tokens]
        tokens = torch.stack(tokens, dim=0)
        common_tokens = torch.all(tokens == tokens[:1, :], dim=0)
        common_prefix_len = 0
        while common_prefix_len < common_tokens.size(0) and common_tokens[common_prefix_len]:
            common_prefix_len += 1
        return common_prefix_len

    def get_common_suffix_len(tokens: list[torch.Tensor]) -> int:
        """Number of trailing positions on which all token sequences agree."""
        max_length = max(t.size(0) for t in tokens)
        tokens = [F.pad(t, (max_length - t.size(0), 0), value=-1) for t in tokens]
        tokens = torch.stack(tokens, dim=0)
        common_tokens = torch.all(tokens == tokens[:1, :], dim=0)
        common_suffix_len = 0
        while common_suffix_len < common_tokens.size(0) and common_tokens[common_tokens.size(0) - common_suffix_len - 1]:
            common_suffix_len += 1
        return common_suffix_len

    n_random_strings = 8  # Lower numbers are faster, but yield incorrect results. P_incorrect~(1/24)^n_random_strings
    out_tokens = []
    n_tokenized_clean = 0
    n_tokenized_attack = 0
    n_turns = len(conversation)

    if conversation[0]["role"] == "system":
        start_idx = 2
    else:
        start_idx = 1

    for i in range(start_idx, n_turns, 2):
        # We work our way through the conversation, section by section.

        # First, lets get the tokens before the user message.
        # For this, we replace the user message with random strings and find the common pre- and suffix.
        # Sadly this cannot be cached as the suffix and sep length depends on the position in the conversation
        empty_convs = [copy.deepcopy(conversation[:i]) for _ in range(n_random_strings)]
        for conv in empty_convs:
            conv[-1]["content"] = generate_random_string(5)
        tokenized_empty = tokenize_chats(empty_convs, tokenizer)
        sep_len = get_common_prefix_len(tokenized_empty)
        common_suffix_len = get_common_suffix_len(tokenized_empty)

        sep = tokenized_empty[0][n_tokenized_clean:sep_len]
        n_tokenized_clean += sep.size(0)
        n_tokenized_attack += sep.size(0)

        # Now the user message itself.
        # Here we have to also take into account the prefix and suffix attack tokens.
        tokenized_clean = tokenize_chats([conversation[:i]], tokenizer)[0][n_tokenized_clean:]
        tokenized_attack = tokenize_chats([conversation_opt[:i]], tokenizer)[0][n_tokenized_attack:]
        if common_suffix_len > 0:
            tokenized_clean = tokenized_clean[:-common_suffix_len]
            tokenized_attack = tokenized_attack[:-common_suffix_len]
        for j in range(len(tokenized_attack)-len(tokenized_clean)+1):
            if torch.equal(tokenized_attack[j:j+len(tokenized_clean)], tokenized_clean):
                prompt = tokenized_attack[j:j+len(tokenized_clean)]
                break
        else:
            raise TokenMergeError(
                "There are tokenizer merges across prompt and attack, cannot split.\n"
                + f"Prompt: {conversation[:i]}\n"
                + f"Attack: {conversation_opt[:i]}\n"
                + f"{tokenized_clean}\n"
                + f"{tokenized_attack}"
            )
        pre_attack = tokenized_attack[:j]
        suf_attack = tokenized_attack[j+len(tokenized_clean):]
        n_tokenized_clean += prompt.size(0)
        n_tokenized_attack += pre_attack.size(0) + prompt.size(0) + suf_attack.size(0)

        # Done with user message, now time for assistant message
        if tokenizer not in TOKENIZER_CACHE:
            empty_convs = [copy.deepcopy(conversation[:i+1]) for _ in range(n_random_strings)]
            for conv in empty_convs:
                conv[-1]["content"] = generate_random_string(5)
            tokenized_empty = tokenize_chats(empty_convs, tokenizer)
            tokenized_empty = [t[n_tokenized_clean:] for t in tokenized_empty if t.size(0) > 0]
            post_len = get_common_prefix_len(tokenized_empty)
            suffix_len = get_common_suffix_len(tokenized_empty)
            TOKENIZER_CACHE[tokenizer] = post_len, suffix_len
        else:
            post_len, suffix_len = TOKENIZER_CACHE[tokenizer]

        tokenized_clean = tokenize_chats([conversation[:i+1]], tokenizer)[0]
        post = tokenized_clean[n_tokenized_clean:n_tokenized_clean+post_len]
        n_tokenized_clean += post.size(0)
        n_tokenized_attack += post.size(0)
        if "llama-2" in tokenizer.name_or_path.lower():
            # LLama 2 models have incorrect templating and need to be fixed manually
            post = torch.cat([post, torch.tensor([29871])])
            # Guard: sep can be empty, in which case there is nothing to strip.
            if sep.numel() > 0 and sep[0] == 29871:
                sep = sep[1:]
        elif "gemma-2" in tokenizer.name_or_path.lower():
            if i != start_idx:
                t = torch.tensor([235248,    108])
                sep = torch.cat([t, sep])
        # Explicit end index: `x[a:-0]` would be empty when the template adds no
        # trailing tokens after the assistant message (suffix_len == 0).
        target_end = tokenized_clean.size(0) - suffix_len
        target = tokenized_clean[n_tokenized_clean:target_end]
        n_tokenized_clean += target.size(0)
        n_tokenized_attack += target.size(0)
        out_tokens.append([sep, pre_attack, prompt, suf_attack, post, target])
    return out_tokens


def generate_random_string(k: int = 5) -> str:
    """Return ``k`` characters sampled (with replacement) from letters, digits and space.

    Uses the module-level ``random`` generator, so results follow ``random.seed``.
    """
    alphabet = string.ascii_letters + string.digits + " "
    picked = random.choices(alphabet, k=k)
    return "".join(picked)


def _make_random_chats(n: int, k: int = 5) -> list[Conversation]:
    """Build ``n`` single-exchange chats whose messages are random ``k``-char strings.

    Each chat is ``[user_message, assistant_message]``; each message is a dict
    with ``"role"`` and ``"content"`` keys. Useful for probing which tokens a
    chat template inserts around message content.
    """
    return [
        [
            {"role": "user", "content": generate_random_string(k)},
            {"role": "assistant", "content": generate_random_string(k)},
        ]
        for _ in range(n)
    ]


def _extract_prefix_middle_suffix(vectors: list[torch.Tensor]) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    def longest_common_prefix(sequences):
        if not sequences:
            return []
        prefix = sequences[0]
        for seq in sequences[1:]:
            min_len = min(len(prefix), len(seq))
            i = 0
            while i < min_len and prefix[i] == seq[i]:
                i += 1
            prefix = prefix[:i]
            if not prefix:
                return []
        return prefix

    def longest_common_suffix(sequences):
        if not sequences:
            return []
        suffix = sequences[0]
        for seq in sequences[1:]:
            min_len = min(len(suffix), len(seq))
            i = 1
            while i <= min_len and suffix[-i] == seq[-i]:
                i += 1
            if i > 1:
                suffix = suffix[-(i - 1) :]
            else:
                return []
        return suffix

    def longest_common_subsequence(sequences):
        if not sequences:
            return []
        reference = sequences[0]
        n = len(reference)
        # Start with the longest possible substrings and decrease length
        for length in range(n, 0, -1):
            for start in range(n - length + 1):
                candidate = reference[start : start + length]
                if all(
                    any(
                        candidate == seq[i : i + length]
                        for i in range(len(seq) - length + 1)
                    )
                    for seq in sequences[1:]
                ):
                    return candidate
        return []

    sequences = [vec.tolist() for vec in vectors]
    prefix = longest_common_prefix(sequences)
    suffix = longest_common_suffix(sequences)
    # Trim the prefix and suffix from sequences
    sequences_trimmed = [
        seq[len(prefix) : len(seq) - len(suffix) if len(suffix) > 0 else None]
        for seq in sequences
    ]
    middle = longest_common_subsequence(sequences_trimmed)
    return torch.tensor(prefix), torch.tensor(middle), torch.tensor(suffix)


def tokenize_chats(chats: list[Conversation], tokenizer) -> list[torch.Tensor]:
    """Render every chat through the tokenizer's chat template, then tokenize it.

    Args:
        chats: Conversations, rendered in a single ``apply_chat_template`` call
            with ``add_generation_prompt=False``.
        tokenizer: Tokenizer exposing ``apply_chat_template``, ``bos_token`` and
            ``__call__`` (Hugging Face style).

    Returns:
        One 1-D tensor of token ids per chat. Each tensor is built on its own
        because the chats can tokenize to different lengths.
    """
    rendered = tokenizer.apply_chat_template(
        chats, tokenize=False, add_generation_prompt=False
    )
    bos = tokenizer.bos_token
    # Some chat templates already emit the BOS token as text. Since the tokenizer
    # call below uses add_special_tokens=True, drop the textual copy to avoid a
    # duplicated BOS.
    if bos:
        for idx, text in enumerate(rendered):
            rendered[idx] = text.removeprefix(bos)

    encoded = tokenizer(rendered, add_special_tokens=True)
    return [torch.tensor(token_ids) for token_ids in encoded.input_ids]


@lru_cache()
def get_tokenized_no_attack(prompt, target, tokenizer):
    """Tokenize the clean (attack-free) exchange: user ``prompt`` then assistant ``target``.

    Memoised with ``lru_cache`` because for most attacks the clean prompt/target
    pair rarely changes between calls. All arguments must therefore be hashable.
    NOTE(review): the same cached tensor object is handed to every caller, so
    callers presumably must not modify it in place — confirm.
    """
    user_turn = {"role": "user", "content": prompt}
    assistant_turn = {"role": "assistant", "content": target}
    tokenized = tokenize_chats([[user_turn, assistant_turn]], tokenizer)
    return tokenized[0]


@lru_cache()
def get_pre_post_suffix_tokens(tokenizer, num_messages):
    """Infer the tokens a chat template places around message contents.

    Tokenizes ``num_messages`` random user/assistant chats and returns the
    token runs shared by all of them as ``(prefix, middle, suffix)`` tensors.
    Random probing is crude but fast. Results are cached per
    ``(tokenizer, num_messages)``, so the random sample is drawn only once for
    each key.
    NOTE(review): with few messages, random contents could accidentally share
    leading/trailing tokens and inflate the detected prefix/suffix — consider a
    larger ``num_messages`` if results look off.
    """
    probe_chats = _make_random_chats(num_messages)
    probe_tokens = tokenize_chats(probe_chats, tokenizer)
    prefix_middle_suffix = _extract_prefix_middle_suffix(probe_tokens)
    return prefix_middle_suffix


def top_p_filtering(logits: torch.Tensor, top_p: float) -> torch.Tensor:
    """Filter logits using nucleus (top-p) sampling.

    Keeps the smallest set of highest-probability tokens whose cumulative
    probability exceeds ``top_p`` (the token that crosses the threshold is
    kept) and sets every other logit to -inf.

    Parameters
    ----------
    logits: torch.Tensor, shape (B, T, V) or (B, V)
        The logits to filter. The filtering is written into this tensor in
        place.
    top_p: float
        The top-p threshold.

    Returns
    -------
    torch.Tensor
        The filtered logits, with the same shape as the input.
    """
    is_2d = logits.ndim == 2
    work = logits.unsqueeze(1) if is_2d else logits

    ranked_logits, order = torch.sort(work, descending=True)
    cum_probs = F.softmax(ranked_logits, dim=-1).cumsum(dim=-1)

    # Mark positions past the threshold, then shift the mask one step right so
    # the first token that crosses top_p is kept; the top token is always kept.
    past_threshold = cum_probs > top_p
    drop_ranked = torch.cat(
        [torch.zeros_like(past_threshold[..., :1]), past_threshold[..., :-1]],
        dim=-1,
    )

    # Map the mask from sorted order back to vocabulary order.
    drop = drop_ranked.scatter(-1, order, drop_ranked)
    work.masked_fill_(drop, float('-inf'))

    return work.squeeze(1) if is_2d else work


def top_k_filtering(logits: torch.Tensor, top_k: int) -> torch.Tensor:
    """Filter logits using top-k sampling.

    Parameters
    ----------
    logits: torch.Tensor, shape (B, T, V) or (B, V)
        The logits to filter. The filtering is written into this tensor in
        place.
    top_k: int
        Number of highest-scoring tokens to keep along the last dimension.
        Values larger than V are clamped to V, which keeps every token.
        ``0`` disables filtering.

    Returns
    -------
    torch.Tensor
        Filtered logits (same shape as the input) with values below the k-th
        largest set to -inf. Entries tied with the k-th value are kept, so more
        than k entries may survive.
    """
    # torch.topk raises when k exceeds the size of the last dimension, and
    # top_k == 0 would make values[..., -1] index an empty tensor. Clamp, and
    # treat 0 as "no filtering".
    top_k = min(top_k, logits.size(-1))
    if top_k == 0:
        return logits

    single_token_only = logits.ndim == 2
    if single_token_only:
        logits = logits.unsqueeze(1)

    values, _ = torch.topk(logits, top_k)
    # Smallest of the top-k values, kept as a trailing singleton dim for broadcasting.
    min_values = values[..., -1, None]
    # Mask out everything strictly below the k-th value.
    logits[logits < min_values] = float('-inf')

    if single_token_only:
        logits = logits.squeeze(1)
    return logits


def get_disallowed_ids(tokenizer: PreTrainedTokenizerBase, allow_non_ascii: bool, allow_special: bool) -> torch.Tensor:
    """Collect the token ids that should be excluded from optimisation.

    Args:
        tokenizer: Tokenizer whose full vocabulary (``len(tokenizer)``) is scanned.
        allow_non_ascii: If False, disallow tokens whose decoded text is not
            printable ASCII.
        allow_special: If False, disallow tokens that decode to an empty string
            when special tokens are skipped. For Gemma-2 tokenizers, also
            disallow ``[@BOS@]`` and ``<unused0>`` … ``<unused99>``.

    Returns:
        Sorted 1-D ``torch.long`` tensor of unique disallowed ids. The BOS, EOS,
        PAD and UNK ids are always included when the tokenizer defines them.
        The tensor may be empty, and is still integer-typed so it remains
        usable as an index.
    """
    disallowed_ids: set[int] = set()

    def is_ascii(s: str) -> bool:
        return s.isascii() and s.isprintable()

    # Important to loop over len(tokenizer), not just tokenizer.vocab_size, because
    # special tokens added post-hoc are not counted to vocab_size.
    if not allow_non_ascii:
        for i in range(len(tokenizer)):
            if not is_ascii(tokenizer.decode([i])):
                disallowed_ids.add(i)

    if not allow_special:
        for i in range(len(tokenizer)):
            if not tokenizer.decode([i], skip_special_tokens=True):
                disallowed_ids.add(i)

        is_gemma_2 = "gemma-2" in tokenizer.name_or_path.lower()
        if is_gemma_2:
            for token in ["[@BOS@]", *(f"<unused{i}>" for i in range(100))]:
                token_id = tokenizer.convert_tokens_to_ids(token)
                # Unknown tokens may map to None when the tokenizer has no unk
                # token. A None in the set would make sorted() below raise.
                if token_id is not None:
                    disallowed_ids.add(token_id)

    for special_id in (
        tokenizer.bos_token_id,
        tokenizer.eos_token_id,
        tokenizer.pad_token_id,
        tokenizer.unk_token_id,
    ):
        if special_id is not None:
            disallowed_ids.add(special_id)

    return torch.tensor(sorted(disallowed_ids), dtype=torch.long)


def filter_suffix(
    tokenizer: PreTrainedTokenizerBase,
    clean_conversation: Conversation,
    ids: list[list[torch.Tensor | None, torch.Tensor | None]]
) -> list[int]:
    """
    Filters out sequences of token ids that are not invariant under decode-encode round trip.

    For each candidate index ``i`` along the search dimension, every turn's
    prefix/suffix ids are decoded to text and wrapped around the corresponding
    user message. The resulting conversation is re-tokenized with
    ``prepare_conversation``. A candidate is retained only if the re-tokenized
    prefix and suffix ids of every turn equal the original ids.

    Example usage:
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.1-8B-Instruct")
    >>> clean_conversation = [
        {"role": "user", "content": "Hello, how are you?"},
        {"role": "assistant", "content": "I'm doing well, thank you!"},
        {"role": "user", "content": "What is the capital of France?"},
        {"role": "assistant", "content": "The capital of France is Paris."}
    ]
    >>> prefix_ids_turn0 = torch.randint(1000, 2000, (512, 10))
    >>> suffix_ids_turn0 = torch.randint(1000, 2000, (512, 10))
    >>> prefix_ids_turn1 = None
    >>> suffix_ids_turn1 = torch.empty((512, 0))
    >>> ids = [[prefix_ids_turn0, suffix_ids_turn0], [prefix_ids_turn1, suffix_ids_turn1]]
    >>> filter_suffix(tokenizer, clean_conversation, ids)

    Parameters
    ----------
    tokenizer : PreTrainedTokenizerBase
        The tokenizer to use.
    clean_conversation : list of dicts
        Each dict contains {"role": ..., "content": ...}. Messages are assumed
        to alternate user/assistant starting with a user message (even indices
        are treated as user turns).
    ids : list of list of tensors
        Outer list indexed by user-turn index.
        Each inner list should contain [prefix_tensor, suffix_tensor],
        both shaped (search_width, n_optim_ids).
        If the turn has no prefix or suffix, the corresponding item can also be None.
    Returns
    -------
    retain_idx : List[int]
        Indices into the search dimension where token ids are stable under decode/encode.
    Raises
    ------
    RuntimeError
        If no candidate survives the round trip. This includes the case where
        every candidate raised ``TokenMergeError`` during re-tokenization.
    """
    # Structural assertions
    assert all(len(turn_ids) == 2 for turn_ids in ids), "Each conversation turn must contain [prefix, suffix]."
    # Number of candidates = largest first dimension among the provided tensors.
    search_width = max(
        (t.size(0) for turn_ids in ids for t in turn_ids if t is not None),
        default=0,
    )
    n_turns = len(clean_conversation)
    # Decode all ids once up front; a missing prefix/suffix decodes to "".
    decoded_tokens: list[tuple[list[str], list[str]]] = []
    for turn_prefix, turn_suffix in ids:
        prefix_decoded = tokenizer.batch_decode(turn_prefix) if turn_prefix is not None else [""] * search_width
        suffix_decoded = tokenizer.batch_decode(turn_suffix) if turn_suffix is not None else [""] * search_width
        decoded_tokens.append((prefix_decoded, suffix_decoded))

    retain_idx = []
    # Reconstruction of the most recent candidate that re-tokenized successfully.
    # Used only for the error report; stays None if every candidate failed.
    last_recon_ids = None
    for i in range(search_width):
        conversation = []
        for j in range(n_turns):
            content = clean_conversation[j]["content"]
            if j % 2 == 0:
                prefix_str, suffix_str = decoded_tokens[j // 2][0][i], decoded_tokens[j // 2][1][i]
                conversation.append({"role": "user", "content": prefix_str + content + suffix_str})
            else:
                conversation.append({"role": "assistant", "content": content})
        try:
            recon_ids = prepare_conversation(tokenizer, clean_conversation, conversation)
        except TokenMergeError:
            continue
        last_recon_ids = recon_ids

        # Presumably each recon_ids[j] is [sep, prefix, prompt, suffix, post, target],
        # so index 1 is the re-tokenized prefix and index 3 the suffix.
        prefix_match = all(
            torch.equal(ids[j][0][i] if ids[j][0] is not None else torch.empty(0), recon_ids[j][1])
            for j in range(len(recon_ids))
        )
        suffix_match = all(
            torch.equal(ids[j][1][i] if ids[j][1] is not None else torch.empty(0), recon_ids[j][3])
            for j in range(len(recon_ids))
        )
        if prefix_match and suffix_match:
            retain_idx.append(i)

    if not retain_idx:
        # This occurs in some cases, e.g. using the Llama-3 tokenizer with a bad initialization.
        # Format the diagnostic defensively: recon_ids may never have been
        # produced if every candidate raised TokenMergeError.
        if last_recon_ids:
            recon_repr = f"{last_recon_ids[-1][1], last_recon_ids[-1][3]}"
        else:
            recon_repr = "<no candidate could be re-tokenized (TokenMergeError)>"
        raise RuntimeError(
            "No token sequences are the same after decoding and re-encoding. "
            "Consider setting `filter_ids=False` or trying a different `optim_str_init`.\n"
            "Here's an example of the token sequence that failed:\n"
            f"{ids[-1] if ids else '<no ids>'}"
            "\n->\n"
            f"{decoded_tokens[-1] if decoded_tokens else '<nothing decoded>'}"
            "\n->\n"
            f"{recon_repr}"
        )
    return retain_idx
