"""
MIT License

Copyright (c) 2024 Gray Swan AI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import copy
import gc
import os  
import json
import logging
import queue
import threading

from dataclasses import dataclass
from typing import List, Optional, Tuple, Union

from tqdm import tqdm

import pandas as pd
import torch.nn.functional as F  
import torch
import transformers
from torch import Tensor
from transformers import set_seed
from scipy.stats import spearmanr

from nanogcg.utils import (
    INIT_CHARS,
    configure_pad_token,
    find_executable_batch_size,
    get_nonascii_toks,
    mellowmax,
)

# Package-wide logger. A stream handler is attached only when no handler is
# configured yet for this logger (or one of its ancestors), so importing the
# module more than once, or into an application that already set up logging,
# does not add a duplicate handler.
logger = logging.getLogger("nanogcg")
if not logger.hasHandlers():
    handler = logging.StreamHandler()
    # Each record shows a timestamp, the emitting file and line, and the message.
    formatter = logging.Formatter(
        "%(asctime)s [%(filename)s:%(lineno)d] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    handler.setFormatter(formatter)
    logger.addHandler(handler)
    logger.setLevel(logging.INFO)


@dataclass
class ProbeSamplingConfig:
    """Settings that switch `GCG` to probe sampling with a secondary draft model.

    When an instance is attached to `GCGConfig.probe_sampling_config`, `GCG`
    also tokenizes and embeds the prompt with the draft model and scores
    candidates through `_compute_candidates_loss_probe_sampling` instead of
    `_compute_candidates_loss_original`.
    """

    # Draft model; `GCG` reads its input-embedding layer and, when prefix
    # caching is enabled, runs it once on the prompt prefix.
    draft_model: transformers.PreTrainedModel
    # Tokenizer used with the draft model; `GCG` configures a pad token on it
    # if it has none.
    draft_tokenizer: transformers.PreTrainedTokenizer
    # NOTE(review): `r` and `sampling_factor` are consumed by the probe-sampling
    # loss routine, which is outside this view; presumably they control the
    # probe-set and filtered-set sizes. Confirm against that routine.
    r: int = 8
    sampling_factor: int = 16


@dataclass
class GCGConfig:
    """Hyperparameters read by `GCG` when running the optimization."""

    # Number of optimization iterations performed by `GCG.run`.
    num_steps: int = 500
    # Initial optimized string. A single string seeds the first buffer entry;
    # a list of strings supplies the initial buffer entries directly.
    optim_str_init: Union[str, List[str]] = (
        "x x x x x x x x x x x x x x x x x x x x "
        + "x x x x x x x x x x x x x x x x x x x x "
        + "x x x x x x x x x x x x x x x x x x x x"
    )
    # Number of candidate sequences sampled per step (1024 noted as an alternative value).
    search_width: int = 128  # 1024
    # Batch size for candidate loss evaluation; None means "all candidates at once".
    batch_size: Optional[int] = None
    # Per-position number of lowest-gradient tokens to sample replacements from.
    topk: int = 256
    # Number of token positions replaced in each sampled candidate.
    n_replace: int = 1
    # Capacity of the attack buffer; 0 keeps only the most recently added entry.
    buffer_size: int = 0
    # NOTE(review): the next three options are not read in the visible part of
    # the file; presumably used by the loss routines. Confirm there.
    use_mellowmax: bool = False
    mellowmax_alpha: float = 1.0
    early_stop: bool = False
    # Whether to precompute a KV cache for the tokens preceding the optimized string.
    use_prefix_cache: bool = False  # Was True, set to False to run
    # When False, non-ASCII tokens are excluded from candidate sampling.
    allow_non_ascii: bool = False
    # Drop candidates whose token ids change after a decode/re-encode round trip.
    filter_ids: bool = True
    # Prepend a single space to the target string before tokenizing it.
    # NOTE(review): an earlier comment here said "Was False, set to True to run",
    # which contradicts the value below; confirm the intended setting.
    add_space_before_target: bool = False
    # Seed passed to `transformers.set_seed`; None skips seeding.
    seed: int = 42
    # NOTE(review): not read in the visible part of the file.
    verbosity: str = "INFO"
    # Enables probe sampling when set; see `ProbeSamplingConfig`.
    probe_sampling_config: Optional[ProbeSamplingConfig] = None


@dataclass
class GCGResult:
    """Summary returned by `GCG.run`."""

    # Smallest value in `losses`.
    best_loss: float
    # Entry of `strings` recorded at the step where `best_loss` occurred.
    best_string: str
    # Loss of the best candidate found at each optimization step.
    losses: List[float]
    # Decoded best-in-buffer string after each optimization step.
    strings: List[str]


class AttackBuffer:
    """Loss-ordered store of candidate token sequences.

    Entries are ``(loss, optim_ids)`` tuples kept in ascending order of loss,
    so index 0 is always the best candidate and index -1 the worst.
    """

    def __init__(self, size: int):
        self.size = size
        self.buffer = []  # (loss, optim_ids) pairs, lowest loss first

    def add(self, loss: float, optim_ids: Tensor) -> None:
        """Insert a candidate, evicting the worst entry when the buffer is full."""
        entry = (loss, optim_ids)

        if self.size == 0:
            # A zero-sized buffer only ever remembers the latest entry.
            self.buffer = [entry]
            return

        if len(self.buffer) >= self.size:
            # Full: overwrite the current worst (last) entry.
            self.buffer[-1] = entry
        else:
            self.buffer.append(entry)

        # Restore ascending-loss order (stable, so ties keep insertion order).
        self.buffer.sort(key=lambda item: item[0])

    def get_best_ids(self) -> Tensor:
        """Return the token ids of the lowest-loss entry."""
        _, best_ids = self.buffer[0]
        return best_ids

    def get_lowest_loss(self) -> float:
        """Return the loss of the best entry."""
        lowest, _ = self.buffer[0]
        return lowest

    def get_highest_loss(self) -> float:
        """Return the loss of the worst entry."""
        highest, _ = self.buffer[-1]
        return highest

    def log_buffer(self, tokenizer):
        """Log every entry as one line, with backslashes and newlines escaped."""
        lines = ["buffer:"]
        for loss, ids in self.buffer:
            text = tokenizer.batch_decode(ids)[0]
            text = text.replace("\\", "\\\\").replace("\n", "\\n")
            lines.append(f"loss: {loss} | string: {text}")
        logger.info("\n".join(lines))


def sample_ids_from_grad(
    ids: Tensor,
    grad: Tensor,
    search_width: int,
    topk: int = 256,
    n_replace: int = 1,
    not_allowed_ids: Optional[Tensor] = None,
):
    """Returns `search_width` combinations of token ids based on the token gradient.

    Args:
        ids : Tensor, shape = (n_optim_ids)
            the sequence of token ids that are being optimized
        grad : Tensor, shape = (n_optim_ids, vocab_size)
            the gradient of the GCG loss computed with respect to the one-hot token embeddings;
            it is not modified by this function
        search_width : int
            the number of candidate sequences to return
        topk : int
            the topk to be used when sampling from the gradient
        n_replace : int
            the number of token positions to update per sequence
        not_allowed_ids : Tensor, shape = (n_ids), optional
            the token ids that should not be used in optimization; `None` (or the
            legacy value `False`) disables the restriction

    Returns:
        sampled_ids : Tensor, shape = (search_width, n_optim_ids)
            sampled token ids
    """
    n_optim_tokens = len(ids)
    original_ids = ids.repeat(search_width, 1)

    # `False` used to be the declared default and crashed on `.to(...)`; treat it like None.
    if not_allowed_ids is not None and not_allowed_ids is not False:
        # Mask on a copy so the caller's gradient tensor is left untouched.
        grad = grad.clone()
        grad[:, not_allowed_ids.to(grad.device)] = float("inf")

    # Most negative gradient entries are the most promising substitutions.
    topk_ids = (-grad).topk(topk, dim=1).indices

    # Random permutation of positions per candidate; keep the first `n_replace`.
    sampled_ids_pos = torch.argsort(torch.rand((search_width, n_optim_tokens), device=grad.device))[..., :n_replace]
    # For every chosen position pick one of its top-k tokens uniformly at random.
    sampled_ids_val = torch.gather(
        topk_ids[sampled_ids_pos],
        2,
        torch.randint(0, topk, (search_width, n_replace, 1), device=grad.device),
    ).squeeze(2)

    new_ids = original_ids.scatter_(1, sampled_ids_pos, sampled_ids_val)

    return new_ids


def filter_ids(ids: Tensor, tokenizer: transformers.PreTrainedTokenizer):
    """Keeps only the token id sequences that survive a decode/re-encode round trip.

    Args:
        ids : Tensor, shape = (search_width, n_optim_ids)
            token ids
        tokenizer : ~transformers.PreTrainedTokenizer
            the model's tokenizer

    Returns:
        filtered_ids : Tensor, shape = (new_search_width, n_optim_ids)
            all token ids that are the same after retokenization

    Raises:
        RuntimeError: if no sequence is unchanged by the round trip
    """

    def _reencode(text):
        # Tokenize the decoded text again, without special tokens, on the ids' device.
        encoding = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(ids.device)
        return encoding["input_ids"][0]

    decoded_texts = tokenizer.batch_decode(ids)
    stable_rows = [row for row, text in zip(ids, decoded_texts) if torch.equal(row, _reencode(text))]

    if stable_rows:
        return torch.stack(stable_rows)

    # This occurs in some cases, e.g. using the Llama-3 tokenizer with a bad initialization
    raise RuntimeError(
        "No token sequences are the same after decoding and re-encoding. "
        "Consider setting `filter_ids=False` or trying a different `optim_str_init`"
    )


class GCG:
    def __init__(
        self,
        model: transformers.PreTrainedModel,
        tokenizer: transformers.PreTrainedTokenizer,
        config: GCGConfig,
    ):
        self.model = model
        self.tokenizer = tokenizer
        self.config = config

        self.embedding_layer = model.get_input_embeddings()
        self.not_allowed_ids = None if config.allow_non_ascii else get_nonascii_toks(tokenizer, device=model.device)
        self.prefix_cache = None
        self.draft_prefix_cache = None

        self.stop_flag = False
        self.patient_time = 5

        self.epoch = 0
        self.run_number = 0
        self.saving_path = ""
        self.expected_target_str = ""
        self.first_difference = 5
        self.expected_token_str = ""
        self.target_token_str = ""
        self.is_verbose = False
        self.target_ids = None
        self.expected_target = None
        self.before_embeds = None
        self.after_embeds = None
        self.target_embeds = None
        self.draft_before_embeds = None
        self.draft_after_embeds = None
        self.draft_target_embeds = None
        self.draft_target_ids = None

        self.draft_model = None
        self.draft_tokenizer = None
        self.draft_embedding_layer = None
        if self.config.probe_sampling_config:
            self.draft_model = self.config.probe_sampling_config.draft_model
            self.draft_tokenizer = self.config.probe_sampling_config.draft_tokenizer
            self.draft_embedding_layer = self.draft_model.get_input_embeddings()
            if self.draft_tokenizer.pad_token is None:
                configure_pad_token(self.draft_tokenizer)

        if model.dtype in (torch.float32, torch.float64):
            logger.warning(
                "Model is in %s. Use a lower precision data type, if possible, for much faster optimization.",
                model.dtype,
            )

        if model.device == torch.device("cpu"):
            logger.warning("Model is on the CPU. Use a hardware accelerator for faster optimization.")

        if not tokenizer.chat_template:
            logger.warning(
                "Tokenizer does not have a chat template. Assuming base model and setting chat template to empty."
            )
            tokenizer.chat_template = "{% for message in messages %}{{ message['content'] }}{% endfor %}"

    def run(
        self,
        messages: Union[str, List[dict]],
        target: str,  # name of the targeted function
        target_index: int,  # Saving the index of the targeted function in the payload
        attack_location: str,  # description or name
        saving_path: str,  # path to save the logs
        run_number: int,  # run number
        expected_target: str,  # expected_first_name_function
        is_verbose: bool,  # fine-grained track of the progress
    ) -> GCGResult:
        """Optimize an adversarial string placed inside a tool-calling prompt.

        The prompt is rendered with the tokenizer's chat template from the LAST
        entry of `messages` (its ``"query"`` as the user turn and its
        ``"functions_str"`` as the tool list). The ``{optim_str}`` placeholder in
        the rendered prompt marks where the optimized tokens go.

        Args:
            messages: sequence of dicts; every entry needs a ``"functions_str"``
                list of function dicts (with ``"name"`` and ``"description"``),
                and the last entry also needs a ``"query"``. The input is deep
                copied and never modified.
            target: string whose tokens the optimization pushes the model to emit.
            target_index: index of the targeted function inside ``"functions_str"``.
            attack_location: when equal to ``"description"``, the placeholder is
                appended to the targeted function's description unless that
                description already contains it. For any other value the
                placeholder must already be present in the rendered prompt.
            saving_path: prefix of the JSON log file; the file written after every
                step is ``saving_path + f"results_gcg_{run_number}.json"``.
            run_number: identifier used in the log file name.
            expected_target: string that is tokenized and stored on
                ``self.expected_target``.
            is_verbose: print per-step diagnostics and run an additional
                "sanity check" generation from the re-tokenized prompt text.

        Returns:
            GCGResult with the per-step losses and strings and the best of each.

        Raises:
            ValueError: if the rendered prompt does not contain exactly one
                ``{optim_str}`` placeholder.
            RuntimeError: propagated from `filter_ids` when no candidate survives
                retokenization.
        """
        model = self.model
        tokenizer = self.tokenizer
        config = self.config

        if config.seed is not None:
            set_seed(config.seed)
            torch.use_deterministic_algorithms(True, warn_only=True)

        # Work on a private copy: the targeted function description is edited below.
        messages = copy.deepcopy(messages)

        print("messages", messages)

        # Reset per-run state.
        self.epoch = 0
        self.run_number = run_number
        self.saving_path = saving_path
        self.expected_target_str = expected_target
        self.first_difference = 5
        self.expected_token_str = ""
        self.target_token_str = ""

        self.patient_time = 5

        self.is_verbose = is_verbose

        assert isinstance(target_index, int), "Invalid target_index"

        # Name and (unmodified) description of the targeted function in every message.
        list_name_target = []
        list_description_target = []
        for d in messages:
            target_function = d["functions_str"][target_index]
            list_name_target.append(target_function["name"])
            list_description_target.append(target_function["description"])

        if attack_location == "description":
            # Append the GCG string at the end of the target description if location not specified.
            # Substring test: a description that already embeds the placeholder must not get a second one.
            if "{optim_str}" not in list_description_target[-1]:
                target_func_dict = messages[-1]["functions_str"][target_index]
                target_func_dict["description"] += "{optim_str}"
                messages[-1]["functions_str"][target_index] = json.loads(json.dumps(target_func_dict))

        conversation = [
            {"role": "user", "content": messages[-1]["query"]},
        ]

        template = tokenizer.apply_chat_template(
            conversation=conversation,
            tools=[{"type": "function", "function": x} for x in messages[-1]["functions_str"]],
            tokenize=False,
            add_generation_prompt=True,
        )

        # Remove the BOS token -- this will get added when tokenizing
        if tokenizer.bos_token and template.startswith(tokenizer.bos_token):
            template = template.replace(tokenizer.bos_token, "")

        print("template", template)

        # Fail with an explicit message instead of an opaque unpacking error.
        n_placeholders = template.count("{optim_str}")
        if n_placeholders != 1:
            raise ValueError(
                "The rendered prompt must contain exactly one '{optim_str}' placeholder, "
                f"found {n_placeholders}."
            )
        before_str, after_str = template.split("{optim_str}")

        target = " " + target if config.add_space_before_target else target

        tokenizer.pad_token = tokenizer.eos_token

        # Tokenize everything that does not get optimized
        before_ids = tokenizer([before_str], padding=True, add_special_tokens=True, return_tensors="pt")[
            "input_ids"
        ].to(model.device, torch.int64)
        after_ids = tokenizer([after_str], add_special_tokens=False, return_tensors="pt")["input_ids"].to(
            model.device, torch.int64
        )
        target_ids = tokenizer([target], add_special_tokens=False, return_tensors="pt")["input_ids"].to(
            model.device, torch.int64
        )

        # Embed everything that does not get optimized
        embedding_layer = self.embedding_layer
        before_embeds, after_embeds, target_embeds = [
            embedding_layer(ids) for ids in (before_ids, after_ids, target_ids)
        ]

        # Compute the KV Cache for tokens that appear before the optimized tokens
        if config.use_prefix_cache:
            with torch.no_grad():
                output = model(inputs_embeds=before_embeds, use_cache=True)
                self.prefix_cache = output.past_key_values

        self.target_ids = target_ids
        self.before_embeds = before_embeds
        self.after_embeds = after_embeds
        self.target_embeds = target_embeds

        self.expected_target = tokenizer([self.expected_target_str], add_special_tokens=False, return_tensors="pt")[
            "input_ids"
        ].to(model.device, torch.int64)

        if config.probe_sampling_config:
            assert (
                self.draft_model and self.draft_tokenizer and self.draft_embedding_layer
            ), "Draft model wasn't properly set up."

            # Tokenize everything that doesn't get optimized for the draft model
            draft_before_ids = self.draft_tokenizer(
                [before_str], padding=True, add_special_tokens=True, return_tensors="pt"
            )["input_ids"].to(model.device, torch.int64)
            draft_after_ids = self.draft_tokenizer([after_str], add_special_tokens=False, return_tensors="pt")[
                "input_ids"
            ].to(model.device, torch.int64)
            self.draft_target_ids = self.draft_tokenizer([target], add_special_tokens=False, return_tensors="pt")[
                "input_ids"
            ].to(model.device, torch.int64)

            (
                self.draft_before_embeds,
                self.draft_after_embeds,
                self.draft_target_embeds,
            ) = [
                self.draft_embedding_layer(ids)
                for ids in (
                    draft_before_ids,
                    draft_after_ids,
                    self.draft_target_ids,
                )
            ]

            if config.use_prefix_cache:
                with torch.no_grad():
                    output = self.draft_model(inputs_embeds=self.draft_before_embeds, use_cache=True)
                    self.draft_prefix_cache = output.past_key_values

        # Initialize the attack buffer
        buffer = self.init_buffer()
        optim_ids = buffer.get_best_ids()

        losses = []
        optim_strings = []

        save_results = []

        for current_iter in tqdm(range(config.num_steps)):
            # Compute the token gradient
            optim_ids_onehot_grad = self.compute_token_gradient(optim_ids)

            with torch.no_grad():

                # Sample candidate token sequences based on the token gradient
                sampled_ids = sample_ids_from_grad(
                    optim_ids.squeeze(0),
                    optim_ids_onehot_grad.squeeze(0),
                    config.search_width,
                    config.topk,
                    config.n_replace,
                    not_allowed_ids=self.not_allowed_ids,
                )

                if config.filter_ids:
                    sampled_ids = filter_ids(sampled_ids, tokenizer)

                new_search_width = sampled_ids.shape[0]

                # Compute loss on all candidate sequences
                batch_size = new_search_width if config.batch_size is None else config.batch_size
                if self.prefix_cache:
                    # The prefix is served from the KV cache, so it is left out of the embeddings.
                    input_embeds = torch.cat(
                        [
                            embedding_layer(sampled_ids),
                            after_embeds.repeat(new_search_width, 1, 1),
                            target_embeds.repeat(new_search_width, 1, 1),
                        ],
                        dim=1,
                    )
                else:
                    input_embeds = torch.cat(
                        [
                            before_embeds.repeat(new_search_width, 1, 1),
                            embedding_layer(sampled_ids),
                            after_embeds.repeat(new_search_width, 1, 1),
                            target_embeds.repeat(new_search_width, 1, 1),
                        ],
                        dim=1,
                    )

                if self.config.probe_sampling_config is None:
                    loss = find_executable_batch_size(self._compute_candidates_loss_original, batch_size)(input_embeds)
                    current_loss = loss.min().item()
                    optim_ids = sampled_ids[loss.argmin()].unsqueeze(0)
                else:
                    current_loss, optim_ids = find_executable_batch_size(
                        self._compute_candidates_loss_probe_sampling, batch_size
                    )(
                        input_embeds,
                        sampled_ids,
                    )

                # Update the buffer based on the loss
                losses.append(current_loss)
                if buffer.size == 0 or current_loss < buffer.get_highest_loss():
                    buffer.add(current_loss, optim_ids)

            optim_ids = buffer.get_best_ids()
            optim_str = tokenizer.batch_decode(optim_ids)[0]
            optim_strings.append(optim_str)

            # Greedy generation from the exact optimized token ids. Concatenating the
            # (1, n) id tensors directly also works when a segment holds a single token.
            prompt_ids = torch.cat([before_ids, optim_ids.to(before_ids.device), after_ids], dim=1)
            input_length = prompt_ids.shape[1]
            outputs = model.generate(
                prompt_ids,
                do_sample=False,
                max_new_tokens=100,
                return_dict_in_generate=True,
                output_logits=True,
            )
            generated_tokens = outputs.sequences[:, input_length:]

            decoded_output = tokenizer.batch_decode(generated_tokens, skip_special_tokens=False)[0]

            if self.is_verbose:
                print(
                    "Decoded Output:",
                    tokenizer.batch_decode(outputs.sequences[:, :], skip_special_tokens=False)[0],
                )

            # Strings that mark the generation steps at which the target token is ranked.
            marker_strs = [self.expected_token_str]

            # First token of the targeted function name.
            self.target_token_str = list_name_target[-1]
            target_token_id = tokenizer.encode(self.target_token_str, add_special_tokens=False)[0]

            logit_results = self._rank_target_token(
                outputs, generated_tokens, target_token_id, marker_strs, verbose=self.is_verbose
            )

            if self.is_verbose:
                self._print_generation_report(logit_results, attack_location, optim_str, decoded_output)

            save_results.append(
                {
                    "epoch": current_iter,
                    "loss": current_loss,
                    "optim_str": list_description_target[-1] + optim_str,
                    "output": decoded_output,
                    "logit_results": logit_results,
                    "optim_ids": optim_ids.tolist(),
                    "template": template,
                }
            )

            # Rewrite the whole log after every step so partial runs leave usable results.
            df = pd.DataFrame(save_results)

            df.to_json(
                saving_path + f"results_gcg_{run_number}.json",
                orient="records",
                indent=4,
            )

            self.epoch = current_iter

            buffer.log_buffer(tokenizer)

            if self.is_verbose:

                print("\n")
                print("!!!START of sanity check!!!")
                print("\n")

                # Same generation, but starting from the prompt TEXT, to expose any
                # difference introduced by re-tokenizing the optimized string.
                sanity_input = tokenizer.encode(
                    before_str + optim_str + after_str, add_special_tokens=True, return_tensors="pt"
                ).to(model.device)
                sanity_outputs = model.generate(
                    sanity_input,
                    do_sample=False,
                    max_new_tokens=100,
                    return_dict_in_generate=True,
                    output_logits=True,
                )
                sanity_tokens = sanity_outputs.sequences[:, sanity_input.shape[1] :]

                sanity_decoded = tokenizer.batch_decode(sanity_tokens, skip_special_tokens=False)[0]

                print(
                    "Decoded Output:",
                    tokenizer.batch_decode(sanity_outputs.sequences[:, :], skip_special_tokens=False)[0],
                )

                sanity_results = self._rank_target_token(
                    sanity_outputs, sanity_tokens, target_token_id, marker_strs, verbose=True
                )
                self._print_generation_report(sanity_results, attack_location, optim_str, sanity_decoded)

                # The CUDA memory helpers are only meaningful (and safe) with CUDA present.
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                    torch.cuda.reset_peak_memory_stats()

                print("\n")
                print("!!!END of sanity check!!!")
                print("\n")

            # NOTE(review): `stop_flag` is not set anywhere in this method; presumably
            # the loss routines raise it -- confirm.
            if self.stop_flag:
                logger.info("Early stopping due to finding a perfect match.")
                break

        min_loss_index = losses.index(min(losses))

        result = GCGResult(
            best_loss=losses[min_loss_index],
            best_string=optim_strings[min_loss_index],
            losses=losses,
            strings=optim_strings,
        )

        return result

    def _rank_target_token(self, outputs, generated_tokens, target_token_id, marker_strs, verbose):
        """Rank `target_token_id` at every generation step whose token matches a marker.

        A step matches when the decoded generated token contains any string of
        `marker_strs` (an empty marker therefore matches every step). Returns a
        list of ``(target_token_text, rank, probability, best_token_text,
        best_probability)`` tuples, one per matching step. With `verbose`, the
        ten highest-logit tokens of each matching step are printed.
        """
        tokenizer = self.tokenizer
        results = []

        for step, token_id in enumerate(generated_tokens[0]):
            token_str = tokenizer.decode(token_id)  # Convert token ID back to text
            if not any(marker in token_str for marker in marker_strs):
                continue

            logits = outputs.logits[step][0]  # Logits for this step (batch entry 0)
            probabilities = F.softmax(logits, dim=0)  # Normalized probabilities

            sorted_indices = torch.argsort(logits, descending=True)
            # Position of the target token in the descending-logit ordering (0 = top).
            rank = (sorted_indices == target_token_id).nonzero(as_tuple=True)[0].item()
            weight = probabilities[target_token_id].item()

            if verbose:
                print([tokenizer.decode(sorted_indices[i].item()) for i in range(10)])

            # Highest-logit token at this step, for comparison.
            best_token_id = sorted_indices[0].item()

            results.append(
                (
                    tokenizer.decode(target_token_id),
                    rank,
                    weight,
                    tokenizer.decode(best_token_id),
                    probabilities[best_token_id].item(),
                )
            )

        return results

    @staticmethod
    def _print_generation_report(logit_results, attack_location, optim_str, decoded_output):
        """Print the ranking tuples followed by the attack location, prompt and generation."""
        for target_token_text, rank, weight, best_token_str, best_weight in logit_results:
            print(f"Generated token: {target_token_text}, Rank: {rank}, Weight: {weight}")
            print(f"Best-ranked token: {best_token_str}, Weight: {best_weight}\n")

        print(f"Testing current string - Attack location:\n{attack_location}\n")
        print(f"Prompt:\n{optim_str}\n")
        print(f"Generation:\n{decoded_output}")

    def init_buffer(self) -> AttackBuffer:
        """Create the attack buffer and fill it with scored initial candidates.

        With a string `optim_str_init`, the first candidate is that string; when
        `buffer_size > 1` the remaining candidates are random sequences of
        `INIT_CHARS` tokens of the same length. With a list, every element is a
        candidate and all must tokenize to the same length.

        Requires `before_embeds`, `after_embeds` and `target_embeds` (and the
        prefix cache, if used) to have been prepared by `run`.

        Returns:
            AttackBuffer holding at most ``max(1, buffer_size)`` of the
            lowest-loss initial candidates, with losses stored as floats.

        Raises:
            ValueError: if a list initialization cannot be tokenized into one
                rectangular tensor.
        """
        model = self.model
        tokenizer = self.tokenizer
        config = self.config

        logger.info(f"Initializing attack buffer of size {config.buffer_size}...")

        # Create the attack buffer and initialize the buffer ids
        buffer = AttackBuffer(config.buffer_size)

        if isinstance(config.optim_str_init, str):
            init_optim_ids = tokenizer(config.optim_str_init, add_special_tokens=False, return_tensors="pt")[
                "input_ids"
            ].to(model.device)
            if config.buffer_size > 1:
                init_buffer_ids = (
                    tokenizer(INIT_CHARS, add_special_tokens=False, return_tensors="pt")["input_ids"]
                    .squeeze()
                    .to(model.device)
                )
                # Random filler sequences with the same length as the main initialization.
                init_indices = torch.randint(
                    0,
                    init_buffer_ids.shape[0],
                    (config.buffer_size - 1, init_optim_ids.shape[1]),
                )
                init_buffer_ids = torch.cat([init_optim_ids, init_buffer_ids[init_indices]], dim=0)
            else:
                init_buffer_ids = init_optim_ids

        else:  # assume list
            if len(config.optim_str_init) != config.buffer_size:
                logger.warning(
                    f"Using {len(config.optim_str_init)} initializations but buffer size is set to {config.buffer_size}"
                )
            try:
                init_buffer_ids = tokenizer(config.optim_str_init, add_special_tokens=False, return_tensors="pt")[
                    "input_ids"
                ].to(model.device)
            except ValueError:
                logger.error("Unable to create buffer. Ensure that all initializations tokenize to the same length.")
                # Without the ids nothing below can work; surface the real cause.
                raise

        # Size the batch by the candidates actually present, which can differ from
        # `buffer_size` when a list of initializations is supplied.
        n_init = init_buffer_ids.shape[0]

        # Compute the loss on the initial buffer entries
        if self.prefix_cache:
            init_buffer_embeds = torch.cat(
                [
                    self.embedding_layer(init_buffer_ids),
                    self.after_embeds.repeat(n_init, 1, 1),
                    self.target_embeds.repeat(n_init, 1, 1),
                ],
                dim=1,
            )
        else:
            init_buffer_embeds = torch.cat(
                [
                    self.before_embeds.repeat(n_init, 1, 1),
                    self.embedding_layer(init_buffer_ids),
                    self.after_embeds.repeat(n_init, 1, 1),
                    self.target_embeds.repeat(n_init, 1, 1),
                ],
                dim=1,
            )

        init_buffer_losses = find_executable_batch_size(self._compute_candidates_loss_original, n_init)(
            init_buffer_embeds
        )

        # Plain floats, matching the losses `run` adds later.
        init_losses = [float(loss) for loss in init_buffer_losses]

        # Populate the buffer best-first and stop at its capacity, so surplus
        # initializations can never evict a better entry.
        capacity = max(1, config.buffer_size)
        best_first = sorted(range(n_init), key=init_losses.__getitem__)
        for i in best_first[:capacity]:
            buffer.add(init_losses[i], init_buffer_ids[[i]])

        buffer.log_buffer(tokenizer)

        logger.info("Initialized attack buffer.")

        return buffer

    def compute_token_gradient(
        self,
        optim_ids: Tensor,
    ) -> Tensor:
        """Computes the gradient of the GCG loss w.r.t the one-hot token matrix.

        The continuation is produced by greedy decoding: the model is run once per
        target position, the logits of the last position are recorded, and the
        arg-max token is embedded and appended to the input before the next step.
        The loss compares the recorded logits with ``self.target_ids``.

        Side effects:
            * ``self.first_difference`` is updated whenever a decoded prediction
              differs from the decoded target at some position.
            * With ``self.is_verbose`` set, per-position diagnostics are printed and
              ``self.expected_token_str`` / ``self.target_token_str`` are stored for
              the tracked position.
            * One record (epoch, generation loss, target loss) is appended to
              ``results_token_gradient_{run_number}.json`` under ``self.saving_path``.

        Args:
            optim_ids : Tensor, shape = (1, n_optim_ids)
                the sequence of token ids that are being optimized

        Returns:
            Tensor with the shape of the one-hot matrix built from `optim_ids`
            (one row of `embedding_layer.num_embeddings` entries per optimized token):
            the gradient of the loss w.r.t. that matrix.
        """
        model = self.model
        embedding_layer = self.embedding_layer

        # Create the one-hot encoding matrix of our optimized token ids
        optim_ids_onehot = torch.nn.functional.one_hot(optim_ids, num_classes=embedding_layer.num_embeddings)
        optim_ids_onehot = optim_ids_onehot.to(model.device, model.dtype)
        optim_ids_onehot.requires_grad_()

        # (1, num_optim_tokens, vocab_size) @ (vocab_size, embed_dim) -> (1, num_optim_tokens, embed_dim)
        optim_embeds = optim_ids_onehot @ embedding_layer.weight

        if self.prefix_cache:
            input_embeds = torch.cat([optim_embeds, self.after_embeds, self.target_embeds], dim=1)
            # NOTE(review): the result of this forward pass is overwritten by the
            # decoding loop below, which is also run without the prefix cache.
            # This branch looks unfinished -- confirm the intended behaviour before
            # relying on the prefix cache here.
            output = model(
                inputs_embeds=input_embeds,
                past_key_values=self.prefix_cache,
                use_cache=True,
            )
        else:
            # The target embeddings are deliberately left out: the continuation is
            # generated greedily below instead of being teacher-forced.
            input_embeds = torch.cat(
                [
                    self.before_embeds,
                    optim_embeds,
                    self.after_embeds,
                ],
                dim=1,
            )

        # Greedy decoding, one forward pass per target position.
        logits_list = []  # last-position logits of every step
        current_embeds = input_embeds
        for _ in range(self.target_ids.shape[1]):
            output = model(inputs_embeds=current_embeds)
            last_logits = output.logits[:, -1, :]  # (batch_size, vocab_size)
            logits_list.append(last_logits.unsqueeze(1))

            next_token = last_logits.argmax(dim=-1, keepdim=True)
            current_embeds = torch.cat([current_embeds, embedding_layer(next_token)], dim=1)

        shift_logits = torch.cat(logits_list, dim=1)
        shift_labels = self.target_ids

        # Greedy prediction at every target position
        predicted_ids = torch.argmax(shift_logits, dim=-1)

        # Compare predictions and targets on their decoded form (batch entry 0 only)
        predicted_tokens = [self.tokenizer.decode(token_id) for token_id in predicted_ids[0].tolist()]
        target_tokens = [self.tokenizer.decode(token_id) for token_id in shift_labels[0].tolist()]

        first_difference = next(
            (
                position
                for position, (pred_token, target_token) in enumerate(zip(predicted_tokens, target_tokens))
                if pred_token != target_token
            ),
            None,
        )

        # Keep the previously recorded position when everything matches this time
        if first_difference is not None:
            self.first_difference = first_difference

        if self.is_verbose:
            # Diagnostics only: work on detached logits so nothing is added to the
            # autograd graph used for the gradient below.
            top_k = 10
            detached_logits = shift_logits.detach()
            sorted_indices = torch.argsort(detached_logits, dim=-1, descending=True)
            softmax_probs = F.softmax(detached_logits, dim=-1)

            # The attribute is only assigned above once a mismatch has been seen, so
            # fall back to None (no detailed analysis) if it does not exist yet.
            tracked_position = getattr(self, "first_difference", None)

            # Assuming batch_size = 1; report at most the first 10 target positions
            for i in range(min(shift_labels.shape[1], 10)):
                top_indices = sorted_indices[0, i, :top_k]  # Top-k token ids for position i
                top_probs = softmax_probs[0, i, top_indices]

                # Pair decoded tokens with their probabilities
                decoded = [self.tokenizer.decode([idx.item()]) for idx in top_indices]
                probs = [prob.item() for prob in top_probs]

                print(f"\n RANK DECODED GENERATION position {i}: ", decoded)
                print(f"\n SOFTMAX GENERATION position {i}: ", probs)

                if i == tracked_position:
                    # --- EXPECTED token analysis ---
                    expected_token_id = self.expected_target[0][i]
                    expected_token_str = self.tokenizer.decode([expected_token_id])
                    self.expected_token_str = expected_token_str

                    rank_tensor = (sorted_indices[0, i] == expected_token_id).nonzero(as_tuple=True)[0]

                    if rank_tensor.numel() == 0:
                        expected_rank = None
                        print(
                            f"Warning: expected token ID {expected_token_id} not found in sorted indices at position {i}."
                        )
                    else:
                        expected_rank = rank_tensor.item()

                    # Softmax probability of the expected token
                    expected_prob = softmax_probs[0, i, expected_token_id].item()

                    print(
                        f"\n EXPECTED NAME TOKEN position {i}: (token='{expected_token_str}', rank={expected_rank}, softmax_prob={expected_prob})"
                    )

                    # --- TARGET token analysis ---
                    target_token_id = shift_labels[0, i].item()
                    target_token_str = self.tokenizer.decode([target_token_id])
                    self.target_token_str = target_token_str

                    # Rank of the target token in the sorted list (0 = best)
                    rank_tensor = (sorted_indices[0, i] == target_token_id).nonzero(as_tuple=True)[0]

                    if rank_tensor.numel() == 0:
                        # Shouldn't happen, but just in case
                        target_rank = None
                        print(
                            f"Warning: target token ID {target_token_id} not found in sorted indices at position {i}."
                        )
                    else:
                        target_rank = rank_tensor.item()

                    # Softmax probability of the target token
                    target_prob = softmax_probs[0, i, target_token_id].item()

                    print(
                        f"\n TARGET position {i}: (token='{target_token_str}', rank={target_rank}, softmax_prob={target_prob})"
                    )

                print("- - " * 20 + "-")

        # Cross-entropy of the logits against the model's own greedy predictions.
        # Computed for every loss mode because it is always written to the results
        # file below (it used to be undefined when mellowmax was enabled).
        with torch.no_grad():
            loss_generation = torch.nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)), predicted_ids.view(-1)
            ).item()

        if self.config.use_mellowmax:
            label_logits = torch.gather(shift_logits, -1, shift_labels.unsqueeze(-1)).squeeze(-1)
            loss = mellowmax(-label_logits, alpha=self.config.mellowmax_alpha, dim=-1)
        else:
            print("TOKEN(S) GENERATION LOSS: ", loss_generation)

            loss = torch.nn.functional.cross_entropy(
                shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
            )
            print("TOKEN(S) TARGET LOSS", loss.item())

        optim_ids_onehot_grad = torch.autograd.grad(outputs=[loss], inputs=[optim_ids_onehot])[0]

        # Append this step to the per-run results file (read-modify-write).
        results_file = self.saving_path + f"results_token_gradient_{self.run_number}.json"

        if os.path.exists(results_file):
            with open(results_file, "r") as f:
                results = json.load(f)
        else:
            results = []

        results.append(
            {
                "epoch": self.epoch,
                "loss_expected": loss_generation,
                "loss_target": loss.item(),
            }
        )

        pd.DataFrame(results).to_json(results_file, orient="records", indent=4)

        return optim_ids_onehot_grad

    def _compute_candidates_loss_original(
        self,
        search_batch_size: int,
        input_embeds: Tensor,
    ) -> Tensor:
        """Evaluates the GCG loss of every candidate sequence, batch by batch.

        The candidates are pushed through ``self.model`` in slices of at most
        `search_batch_size` rows with gradients disabled. When
        ``self.config.early_stop`` is set and some candidate's arg-max predictions
        reproduce the whole target, ``self.stop_flag`` is switched on.

        Args:
            search_batch_size : int
                the number of candidate sequences to evaluate in a given batch
            input_embeds : Tensor, shape = (search_width, seq_len, embd_dim)
                the embeddings of the `search_width` candidate sequences to evaluate

        Returns:
            Tensor with one loss value per candidate, in input order.
        """
        n_candidates = input_embeds.shape[0]
        # Offset of the first target position inside each embedded sequence
        target_start = input_embeds.shape[1] - self.target_ids.shape[1]

        per_batch_losses = []
        expanded_cache = []

        for start in range(0, n_candidates, search_batch_size):
            with torch.no_grad():
                chunk = input_embeds[start : start + search_batch_size]
                chunk_size = chunk.shape[0]

                if not self.prefix_cache:
                    outputs = self.model(inputs_embeds=chunk)
                else:
                    # (Re)build the expanded cache on first use and whenever the
                    # chunk is smaller than a full batch (i.e. the final chunk).
                    needs_rebuild = not expanded_cache or chunk_size != search_batch_size
                    if needs_rebuild:
                        expanded_cache = []
                        for layer in range(len(self.prefix_cache)):
                            expanded_cache.append(
                                [kv.expand(chunk_size, -1, -1, -1) for kv in self.prefix_cache[layer]]
                            )

                    outputs = self.model(
                        inputs_embeds=chunk,
                        past_key_values=expanded_cache,
                        use_cache=True,
                    )

                logits = outputs.logits

                # Logits at position t predict token t + 1, hence the shift by one
                shift_logits = logits[..., target_start - 1 : -1, :].contiguous()
                shift_labels = self.target_ids.repeat(chunk_size, 1)

                if not self.config.use_mellowmax:
                    vocab_size = shift_logits.size(-1)
                    loss = torch.nn.functional.cross_entropy(
                        shift_logits.view(-1, vocab_size),
                        shift_labels.view(-1),
                        reduction="none",
                    )
                else:
                    label_logits = torch.gather(shift_logits, -1, shift_labels.unsqueeze(-1)).squeeze(-1)
                    loss = mellowmax(-label_logits, alpha=self.config.mellowmax_alpha, dim=-1)

                # One scalar per candidate
                per_batch_losses.append(loss.view(chunk_size, -1).mean(dim=-1))

                if self.config.early_stop:
                    token_matches = torch.argmax(shift_logits, dim=-1) == shift_labels
                    full_match = torch.all(token_matches, dim=-1)
                    if torch.any(full_match).item():
                        self.stop_flag = True

                # Release the batch outputs before the next forward pass
                del outputs
                gc.collect()
                torch.cuda.empty_cache()

        return torch.cat(per_batch_losses, dim=0)

    def _compute_candidates_loss_probe_sampling(
        self,
        search_batch_size: int,
        input_embeds: Tensor,
        sampled_ids: Tensor,
    ) -> Tuple[float, Tensor]:
        """Computes the GCG loss using probe sampling (https://arxiv.org/abs/2403.01251).

        Procedure:
            1. score every candidate with the draft model;
            2. in parallel, score a random probe subset with the target model;
            3. measure how well both models agree on the probe subset (Spearman
               rank correlation, mapped to ``alpha`` in [0, 1]);
            4. re-score the ``(1 - alpha) * B / r`` candidates the draft model likes
               best with the target model;
            5. return the best candidate found among the probe and filtered sets.

        Args:
            search_batch_size : int
                the number of candidate sequences to evaluate in a given batch
            input_embeds : Tensor, shape = (search_width, seq_len, embd_dim)
                the embeddings of the `search_width` candidate sequences to evaluate
            sampled_ids: Tensor, all candidate token id sequences. Used for further sampling.

        Returns:
            A tuple of (min_loss: float, corresponding_sequence: Tensor)

        Raises:
            Exception: whatever a worker thread raised is re-raised in the calling
                thread instead of surfacing as a missing-result ``KeyError``.
        """
        probe_sampling_config = self.config.probe_sampling_config
        assert probe_sampling_config, "Probe sampling config wasn't set up properly."

        B = input_embeds.shape[0]
        # Always probe at least one candidate: an empty probe set cannot be scored
        # (happens when there are fewer candidates than `sampling_factor`).
        probe_size = max(1, B // probe_sampling_config.sampling_factor)
        probe_idxs = torch.randperm(B)[:probe_size].to(input_embeds.device)
        probe_embeds = input_embeds[probe_idxs]

        def _compute_probe_losses(search_batch_size: int, probe_embeds: Tensor) -> Tensor:
            # Target-model loss of the probe subset
            return self._compute_candidates_loss_original(search_batch_size, probe_embeds)

        def _compute_draft_losses(search_batch_size: int, draft_sampled_ids: Tensor) -> Tensor:
            # Draft-model loss of every candidate
            assert (
                self.draft_model and self.draft_embedding_layer
            ), "Draft model and embedding layer weren't initialized properly."

            draft_losses = []
            draft_prefix_cache_batch = None
            for i in range(0, B, search_batch_size):
                with torch.no_grad():
                    batch_size = min(search_batch_size, B - i)
                    draft_sampled_ids_batch = draft_sampled_ids[i : i + batch_size]

                    if self.draft_prefix_cache:
                        # (Re)build the expanded cache on first use and for a smaller final batch
                        if not draft_prefix_cache_batch or batch_size != search_batch_size:
                            draft_prefix_cache_batch = [
                                [x.expand(batch_size, -1, -1, -1) for x in self.draft_prefix_cache[layer]]
                                for layer in range(len(self.draft_prefix_cache))
                            ]
                        draft_embeds = torch.cat(
                            [
                                self.draft_embedding_layer(draft_sampled_ids_batch),
                                self.draft_after_embeds.repeat(batch_size, 1, 1),
                                self.draft_target_embeds.repeat(batch_size, 1, 1),
                            ],
                            dim=1,
                        )
                        draft_output = self.draft_model(
                            inputs_embeds=draft_embeds,
                            past_key_values=draft_prefix_cache_batch,
                        )
                    else:
                        draft_embeds = torch.cat(
                            [
                                self.draft_before_embeds.repeat(batch_size, 1, 1),
                                self.draft_embedding_layer(draft_sampled_ids_batch),
                                self.draft_after_embeds.repeat(batch_size, 1, 1),
                                self.draft_target_embeds.repeat(batch_size, 1, 1),
                            ],
                            dim=1,
                        )
                        draft_output = self.draft_model(inputs_embeds=draft_embeds)

                    draft_logits = draft_output.logits
                    tmp = draft_embeds.shape[1] - self.draft_target_ids.shape[1]
                    # Logits at position t predict token t + 1, hence the shift by one
                    shift_logits = draft_logits[..., tmp - 1 : -1, :].contiguous()
                    shift_labels = self.draft_target_ids.repeat(batch_size, 1)

                    if self.config.use_mellowmax:
                        label_logits = torch.gather(shift_logits, -1, shift_labels.unsqueeze(-1)).squeeze(-1)
                        loss = mellowmax(-label_logits, alpha=self.config.mellowmax_alpha, dim=-1)
                    else:
                        loss = (
                            torch.nn.functional.cross_entropy(
                                shift_logits.view(-1, shift_logits.size(-1)),
                                shift_labels.view(-1),
                                reduction="none",
                            )
                            .view(batch_size, -1)
                            .mean(dim=-1)
                        )

                    draft_losses.append(loss)

            return torch.cat(draft_losses)

        def _convert_to_draft_tokens(token_ids: Tensor) -> Tensor:
            # Re-tokenize the candidates with the draft tokenizer via their decoded text
            decoded_text_list = self.tokenizer.batch_decode(token_ids)
            assert self.draft_tokenizer, "Draft tokenizer wasn't properly initialized."
            return self.draft_tokenizer(
                decoded_text_list,
                add_special_tokens=False,
                padding=True,
                return_tensors="pt",
            )["input_ids"].to(self.draft_model.device, torch.int64)

        result_queue = queue.Queue()

        def _run_and_report(key: str, fn, *args) -> None:
            # Thread entry point: hand back either the result or the exception so a
            # failure is not lost inside the worker thread.
            try:
                outcome = fn(*args)
            except Exception as exc:  # re-raised in the calling thread below
                outcome = exc
            result_queue.put((key, outcome))

        draft_sampled_ids = _convert_to_draft_tokens(sampled_ids)

        # Step 1. Compute loss of all candidates using the draft model
        draft_thread = threading.Thread(
            target=_run_and_report,
            args=("draft", _compute_draft_losses, search_batch_size, draft_sampled_ids),
        )

        # Step 2. In parallel to 1., compute loss of the probe set on target model
        probe_thread = threading.Thread(
            target=_run_and_report,
            args=("probe", _compute_probe_losses, search_batch_size, probe_embeds),
        )

        draft_thread.start()
        probe_thread.start()

        draft_thread.join()
        probe_thread.join()

        results = {}
        while not result_queue.empty():
            key, value = result_queue.get()
            results[key] = value

        # Surface worker failures with their real type, so a caller that reacts to
        # specific errors sees them rather than a KeyError for the missing result.
        for key in ("draft", "probe"):
            if isinstance(results[key], Exception):
                raise results[key]

        probe_losses = results["probe"]
        draft_losses = results["draft"]

        # Step 3. Calculate agreement score using Spearman correlation
        draft_probe_losses = draft_losses[probe_idxs]
        rank_correlation = spearmanr(
            probe_losses.cpu().type(torch.float32).numpy(),
            draft_probe_losses.cpu().type(torch.float32).numpy(),
        ).correlation
        # The correlation is NaN when it is undefined (e.g. one of the loss vectors
        # is constant). NaN fails every comparison, so this check catches it; treat
        # the two models as uncorrelated instead of crashing in int() below.
        if not (-1.0 <= rank_correlation <= 1.0):
            rank_correlation = 0.0
        # normalized from [-1, 1] to [0, 1]
        alpha = (1 + rank_correlation) / 2

        # Step 4. Calculate the filtered set and evaluate using the target model.
        R = probe_sampling_config.r
        filtered_size = int((1 - alpha) * B / R)
        filtered_size = max(1, min(filtered_size, B))

        _, top_indices = torch.topk(draft_losses, k=filtered_size, largest=False)

        filtered_embeds = input_embeds[top_indices]
        filtered_losses = self._compute_candidates_loss_original(search_batch_size, filtered_embeds)

        # Step 5. Return best loss between probe set and filtered set
        best_probe_loss = probe_losses.min().item()
        best_filtered_loss = filtered_losses.min().item()

        probe_ids = sampled_ids[probe_idxs]
        filtered_ids = sampled_ids[top_indices]
        return (
            (best_probe_loss, probe_ids[probe_losses.argmin()].unsqueeze(0))
            if best_probe_loss < best_filtered_loss
            else (
                best_filtered_loss,
                filtered_ids[filtered_losses.argmin()].unsqueeze(0),
            )
        )
