import copy
import gc
import logging
import queue
import threading

from dataclasses import dataclass
from tqdm import tqdm
from typing import List, Optional, Tuple, Union

import torch
import transformers
from torch import Tensor
from transformers import set_seed
from scipy.stats import spearmanr

from utils import (
    INIT_CHARS,
    configure_pad_token,
    find_executable_batch_size,
    get_nonascii_toks,
    mellowmax,
)

# Package logger. A stream handler is attached only when neither this logger
# nor any of its ancestors already has one, so an application that configured
# logging beforehand keeps its own setup.
logger = logging.getLogger("nanogcg")
if not logger.hasHandlers():
    formatter = logging.Formatter(
        fmt="%(asctime)s [%(filename)s:%(lineno)d] %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )
    handler = logging.StreamHandler()
    handler.setFormatter(formatter)
    logger.setLevel(logging.INFO)
    logger.addHandler(handler)


@dataclass
class ProbeSamplingConfig:
    """Settings for probe sampling with a draft model.

    `GCG.__init__` copies `draft_model` and `draft_tokenizer` onto the
    optimizer; `GCG.run` currently raises `NotImplementedError` whenever a
    probe sampling config is supplied.
    """

    # Draft model; GCG.__init__ reads its input embedding layer.
    draft_model: transformers.PreTrainedModel
    # Tokenizer of the draft model; GCG.__init__ configures a pad token on it if it has none.
    draft_tokenizer: transformers.PreTrainedTokenizer
    # NOTE(review): `r` and `sampling_factor` are not read anywhere in the visible
    # code; presumably they tune probe sampling -- confirm before relying on them.
    r: int = 8
    sampling_factor: int = 16


@dataclass
class GCGConfig:
    """Hyperparameters and switches read by `GCG` and by the module-level `run`."""

    # Number of optimization iterations performed by `GCG.run`.
    num_steps: int = 250
    # Initial optimized string, or a list of strings used to seed the attack buffer.
    optim_str_init: Union[str, List[str]] = "x x x x x x x x x x x x x x x x x x x x"
    # Number of candidate sequences requested from `sample_ids_from_grad` per step.
    search_width: int = 512
    # Starting batch size handed to `find_executable_batch_size` when scoring
    # candidates; None means "use the number of candidates".
    batch_size: Optional[int] = None
    # `topk` argument forwarded to `sample_ids_from_grad`.
    topk: int = 256
    # Number of token positions replaced in each candidate.
    n_replace: int = 1
    # Size of the `AttackBuffer`; 0 keeps only the most recently added entry.
    buffer_size: int = 0
    # Use `mellowmax` over the negated target logits instead of cross-entropy.
    use_mellowmax: bool = False
    # `alpha` argument forwarded to `mellowmax`.
    mellowmax_alpha: float = 1.0
    # Set the stop flag as soon as some candidate's argmax prediction equals the
    # target at every target position; `GCG.run` then leaves its loop.
    early_stop: bool = False
    # Precompute `past_key_values` for the text preceding the optimized string.
    use_prefix_cache: bool = True
    # When False, the ids returned by `get_nonascii_toks` are excluded from sampling.
    allow_non_ascii: bool = False
    # Drop candidates whose ids change after a decode / re-encode round trip.
    filter_ids: bool = True
    # Prepend a single space to the target string before tokenizing it.
    add_space_before_target: bool = False
    # When not None, passed to `set_seed` and deterministic algorithms are requested.
    seed: Optional[int] = None
    # Name of a `logging` level; looked up with `getattr(logging, verbosity)`.
    verbosity: str = "INFO"
    # Probe sampling settings; `GCG.run` raises NotImplementedError when this is set.
    probe_sampling_config: Optional[ProbeSamplingConfig] = None


@dataclass
class GCGResult:
    """Outcome of one `GCG.run` call."""

    # Smallest entry of `losses`.
    best_loss: float
    # Entry of `strings` recorded at the step where `best_loss` occurred.
    best_string: str
    # Lowest candidate loss found at each optimization step.
    losses: List[float]
    # Decoded best string held by the attack buffer after each step.
    strings: List[str]


class AttackBuffer:
    """Bounded collection of `(loss, optim_ids)` pairs kept sorted by ascending loss.

    With `size == 0` the buffer degenerates to a single slot that always holds
    the most recently added pair.
    """

    def __init__(self, size: int):
        self.size = size
        self.buffer = []  # (loss: float, optim_ids: Tensor) pairs, lowest loss first

    def add(self, loss: float, optim_ids: Tensor) -> None:
        """Insert a pair; once the buffer is full the current worst entry is overwritten.

        No comparison against the evicted entry is made here, so callers are
        expected to check the loss themselves before adding.
        """
        entry = (loss, optim_ids)

        if self.size == 0:
            self.buffer = [entry]
            return

        if len(self.buffer) >= self.size:
            # Full: the last element is the highest-loss one after sorting.
            self.buffer[-1] = entry
        else:
            self.buffer.append(entry)

        self.buffer.sort(key=lambda pair: pair[0])

    def get_best_ids(self) -> Tensor:
        """Return the ids stored with the lowest loss."""
        _, best_ids = self.buffer[0]
        return best_ids

    def get_lowest_loss(self) -> float:
        """Return the smallest loss in the buffer."""
        lowest, _ = self.buffer[0]
        return lowest

    def get_highest_loss(self) -> float:
        """Return the largest loss in the buffer."""
        highest, _ = self.buffer[-1]
        return highest

    def log_buffer(self, tokenizer):
        """Log every entry as one multi-line INFO message.

        Backslashes and newlines in the decoded strings are escaped so each
        entry stays on a single line.
        """
        lines = ["buffer:"]
        for loss, ids in self.buffer:
            text = tokenizer.batch_decode(ids)[0]
            text = text.replace("\\", "\\\\").replace("\n", "\\n")
            lines.append(f"loss: {loss} | string: {text}")
        logger.info("\n".join(lines))


def sample_ids_from_grad(
    ids: Tensor,
    grad: Tensor,
    search_width: int,
    topk: int = 256,
    n_replace: int = 1,
    not_allowed_ids: Optional[Tensor] = None,
):
    """Returns `search_width` combinations of token ids based on the token gradient.

    Each returned row is a copy of `ids` in which `n_replace` randomly chosen
    positions are overwritten with a token drawn uniformly from the `topk`
    tokens that have the most negative gradient at that position.

    Args:
        ids : Tensor, shape = (n_optim_ids)
            the sequence of token ids that are being optimized
        grad : Tensor, shape = (n_optim_ids, vocab_size)
            the gradient of the GCG loss computed with respect to the one-hot token embeddings;
            it is not modified
        search_width : int
            the number of candidate sequences to return
        topk : int
            the topk to be used when sampling from the gradient (capped at vocab_size)
        n_replace : int
            the number of token positions to update per sequence (capped at n_optim_ids)
        not_allowed_ids : Tensor, shape = (n_ids), optional
            the token ids that should not be used in optimization; `None` (or the
            legacy value `False`) disables the restriction

    Returns:
        sampled_ids : Tensor, shape = (search_width, n_optim_ids)
            sampled token ids
    """
    n_optim_tokens = len(ids)
    original_ids = ids.repeat(search_width, 1)

    # Keep the two sizes within what the tensors can actually provide, so small
    # vocabularies or short sequences do not make topk / gather raise.
    topk = min(topk, grad.shape[1])
    n_replace = min(n_replace, n_optim_tokens)

    # `False` used to be the default and crashed on `.to(...)`; treat it like None.
    if not_allowed_ids is not None and not_allowed_ids is not False:
        # Work on a copy so the caller's gradient tensor is left untouched.
        grad = grad.clone()
        grad[:, not_allowed_ids.to(grad.device)] = float("inf")

    # Most negative gradient == largest expected loss decrease.
    topk_ids = (-grad).topk(topk, dim=1).indices

    # A random permutation per row, truncated to n_replace, yields distinct positions.
    sampled_ids_pos = torch.argsort(torch.rand((search_width, n_optim_tokens), device=grad.device))[..., :n_replace]
    sampled_ids_val = torch.gather(
        topk_ids[sampled_ids_pos],
        2,
        torch.randint(0, topk, (search_width, n_replace, 1), device=grad.device),
    ).squeeze(2)

    # `original_ids` is a fresh tensor produced by `repeat`, so the in-place scatter is safe.
    new_ids = original_ids.scatter_(1, sampled_ids_pos, sampled_ids_val)

    return new_ids


def filter_ids(ids: Tensor, tokenizer: transformers.PreTrainedTokenizer):
    """Keeps only the sequences of token ids that survive a decode / re-encode round trip.

    Args:
        ids : Tensor, shape = (search_width, n_optim_ids)
            token ids
        tokenizer : ~transformers.PreTrainedTokenizer
            the model's tokenizer

    Returns:
        filtered_ids : Tensor, shape = (new_search_width, n_optim_ids)
            all token ids that are the same after retokenization

    Raises:
        RuntimeError: if no sequence is unchanged by the round trip.
    """
    decoded_strings = tokenizer.batch_decode(ids)

    def _survives_roundtrip(row: Tensor, text: str) -> bool:
        # Retokenize the decoded text without special tokens and compare id-for-id.
        encoding = tokenizer(text, return_tensors="pt", add_special_tokens=False).to(ids.device)
        return torch.equal(row, encoding["input_ids"][0])

    kept_rows = [row for row, text in zip(ids, decoded_strings) if _survives_roundtrip(row, text)]

    if kept_rows:
        return torch.stack(kept_rows)

    # This occurs in some cases, e.g. using the Llama-3 tokenizer with a bad initialization
    raise RuntimeError(
        "No token sequences are the same after decoding and re-encoding. "
        "Consider setting `filter_ids=False` or trying a different `optim_str_init`"
    )

class GCG:
    """GCG optimizer that searches for one string minimizing the target loss across several prompts.

    The same optimized token sequence is inserted at the `{optim_str}`
    placeholder of every prompt; candidate losses are averaged over prompts.
    """

    def __init__(
        self,
        model: transformers.PreTrainedModel,
        tokenizer: transformers.PreTrainedTokenizer,
        config: GCGConfig,
    ):
        """Store the model, tokenizer and config and derive prompt-independent state.

        Side effects: logs warnings for float32/float64 or CPU models, assigns a
        pass-through chat template to `tokenizer` when it has none, and
        configures a pad token on the draft tokenizer when one is missing.
        """
        self.model = model
        self.tokenizer = tokenizer
        self.config = config

        self.embedding_layer = model.get_input_embeddings()
        self.not_allowed_ids = None if config.allow_non_ascii else get_nonascii_toks(tokenizer, device=model.device)

        self.prefix_cache_list = None
        self.draft_prefix_cache_list = None

        self.stop_flag = False

        self.draft_model = None
        self.draft_tokenizer = None
        self.draft_embedding_layer = None
        if self.config.probe_sampling_config:
            self.draft_model = self.config.probe_sampling_config.draft_model
            self.draft_tokenizer = self.config.probe_sampling_config.draft_tokenizer
            self.draft_embedding_layer = self.draft_model.get_input_embeddings()
            if self.draft_tokenizer.pad_token is None:
                configure_pad_token(self.draft_tokenizer)

        if model.dtype in (torch.float32, torch.float64):
            logger.warning(f"Model is in {model.dtype}. Use a lower precision data type, if possible, for much faster optimization.")

        if model.device == torch.device("cpu"):
            logger.warning("Model is on the CPU. Use a hardware accelerator for faster optimization.")

        if not tokenizer.chat_template:
            logger.warning("Tokenizer does not have a chat template. Assuming base model and setting chat template to empty.")
            tokenizer.chat_template = "{% for message in messages %}{{ message['content'] }}{% endfor %}"

    def run(
        self,
        multi_messages: List[Union[str, List[dict]]],
        target: str,
    ) -> GCGResult:
        """Optimize a single string against every prompt in `multi_messages`.

        Args:
            multi_messages: one entry per prompt; each entry is either a plain
                user string or a conversation (list of role/content dicts).
                A prompt without `{optim_str}` gets the placeholder appended to
                its last message.
            target: the text the model should produce after each prompt.

        Returns:
            GCGResult with the per-step losses and strings and the best of them.

        Raises:
            NotImplementedError: if `config.probe_sampling_config` is set.
            ValueError: if `multi_messages` is empty, or a rendered prompt does
                not contain `{optim_str}` exactly once.
        """
        model = self.model
        tokenizer = self.tokenizer
        config = self.config

        # Fail before any forward pass is spent on an unsupported configuration.
        if config.probe_sampling_config:
            raise NotImplementedError("Probe sampling for multi-prompt optimization is not implemented in this example.")

        if not multi_messages:
            raise ValueError("multi_messages must contain at least one prompt.")

        # A bare list of message dicts is one conversation, not several prompts.
        if all(isinstance(m, dict) for m in multi_messages):
            multi_messages = [list(multi_messages)]

        # A previous run on this instance may have left the flag set.
        self.stop_flag = False

        if config.seed is not None:
            set_seed(config.seed)
            torch.use_deterministic_algorithms(True, warn_only=True)

        before_str_list = []
        after_str_list = []

        for messages in multi_messages:
            if isinstance(messages, str):
                messages = [{"role": "user", "content": messages}]
            else:
                # Copy so appending the placeholder never mutates the caller's data.
                messages = copy.deepcopy(messages)

            if not any("{optim_str}" in d["content"] for d in messages):
                messages[-1]["content"] = messages[-1]["content"] + "{optim_str}"

            template = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

            # Drop only the leading BOS (the tokenizer call below may add its own);
            # occurrences elsewhere in the text are left alone.
            bos_token = tokenizer.bos_token
            if bos_token and template.startswith(bos_token):
                template = template[len(bos_token):]

            parts = template.split("{optim_str}")
            if len(parts) != 2:
                raise ValueError("Each message in multi_messages must contain the '{optim_str}' placeholder exactly once.")
            before_str_list.append(parts[0])
            after_str_list.append(parts[1])

        target = " " + target if config.add_space_before_target else target

        before_ids_list = [tokenizer([s], padding=False, return_tensors="pt")["input_ids"].to(model.device, torch.int64) for s in before_str_list]
        after_ids_list = [tokenizer([s], add_special_tokens=False, return_tensors="pt")["input_ids"].to(model.device, torch.int64) for s in after_str_list]
        self.target_ids = tokenizer([target], add_special_tokens=False, return_tensors="pt")["input_ids"].to(model.device, torch.int64)

        embedding_layer = self.embedding_layer
        self.before_embeds_list = [embedding_layer(ids) for ids in before_ids_list]
        self.after_embeds_list = [embedding_layer(ids) for ids in after_ids_list]
        self.target_embeds = embedding_layer(self.target_ids)

        if config.use_prefix_cache:
            # One key/value cache per prompt, covering the text before the optimized string.
            self.prefix_cache_list = []
            with torch.no_grad():
                for before_embeds in self.before_embeds_list:
                    output = model(inputs_embeds=before_embeds, use_cache=True)
                    self.prefix_cache_list.append(output.past_key_values)

        buffer = self.init_buffer()
        optim_ids = buffer.get_best_ids()

        losses = []
        optim_strings = []

        for _ in tqdm(range(config.num_steps)):
            # Compute the token gradient based on the total loss from all prompts
            optim_ids_onehot_grad = self.compute_token_gradient(optim_ids)

            with torch.no_grad():
                # Sample candidate token sequences based on the token gradient
                sampled_ids = sample_ids_from_grad(
                    optim_ids.squeeze(0),
                    optim_ids_onehot_grad.squeeze(0),
                    config.search_width,
                    config.topk,
                    config.n_replace,
                    not_allowed_ids=self.not_allowed_ids,
                )

                if config.filter_ids:
                    sampled_ids = filter_ids(sampled_ids, tokenizer)

                batch_size = config.batch_size if config.batch_size is not None else sampled_ids.shape[0]

                loss = find_executable_batch_size(self._compute_candidates_loss_original, batch_size)(sampled_ids)

                current_loss = loss.min().item()
                optim_ids = sampled_ids[loss.argmin()].unsqueeze(0)

                # Update the buffer based on the loss
                losses.append(current_loss)
                if buffer.size == 0 or current_loss < buffer.get_highest_loss():
                    buffer.add(current_loss, optim_ids)

            # Continue from the best sequence seen so far, not necessarily this step's.
            optim_ids = buffer.get_best_ids()
            optim_str = tokenizer.batch_decode(optim_ids)[0]
            optim_strings.append(optim_str)

            buffer.log_buffer(tokenizer)

            if self.stop_flag:
                logger.info("Early stopping due to finding a perfect match.")
                break

        if not losses:
            # num_steps == 0: report the initial buffer instead of failing on min([]).
            return GCGResult(
                best_loss=buffer.get_lowest_loss(),
                best_string=tokenizer.batch_decode(optim_ids)[0],
                losses=losses,
                strings=optim_strings,
            )

        min_loss_index = losses.index(min(losses))

        return GCGResult(
            best_loss=losses[min_loss_index],
            best_string=optim_strings[min_loss_index],
            losses=losses,
            strings=optim_strings,
        )

    def init_buffer(self) -> AttackBuffer:
        """Create the attack buffer and fill it with scored initial sequences.

        A string `optim_str_init` becomes the first entry; when
        `buffer_size > 1` the remaining entries are random sequences of the
        same length drawn from the tokens of `INIT_CHARS`. A list
        `optim_str_init` is tokenized as one batch.

        Raises:
            ValueError: re-raised from the tokenizer when a list of
                initializations cannot be batched into one tensor.
        """
        model = self.model
        tokenizer = self.tokenizer
        config = self.config

        logger.info(f"Initializing attack buffer of size {config.buffer_size}...")
        buffer = AttackBuffer(config.buffer_size)

        if isinstance(config.optim_str_init, str):
            init_optim_ids = tokenizer(config.optim_str_init, add_special_tokens=False, return_tensors="pt")["input_ids"].to(model.device)
            if config.buffer_size > 1:
                init_buffer_ids = tokenizer(INIT_CHARS, add_special_tokens=False, return_tensors="pt")["input_ids"].squeeze().to(model.device)
                init_indices = torch.randint(0, init_buffer_ids.shape[0], (config.buffer_size - 1, init_optim_ids.shape[1]))
                init_buffer_ids = torch.cat([init_optim_ids, init_buffer_ids[init_indices]], dim=0)
            else:
                init_buffer_ids = init_optim_ids
        else:  # assume list
            if len(config.optim_str_init) != config.buffer_size:
                logger.warning(f"Using {len(config.optim_str_init)} initializations but buffer size is set to {config.buffer_size}")
            try:
                init_buffer_ids = tokenizer(config.optim_str_init, add_special_tokens=False, return_tensors="pt")["input_ids"].to(model.device)
            except ValueError:
                logger.error("Unable to create buffer. Ensure that all initializations tokenize to the same length.")
                # Without ids there is nothing to score; surface the real cause.
                raise

        true_buffer_size = max(1, config.buffer_size)
        init_buffer_losses = find_executable_batch_size(self._compute_candidates_loss_original, true_buffer_size)(init_buffer_ids)

        # A list init may hold fewer sequences than the buffer has slots.
        n_initial = min(true_buffer_size, init_buffer_ids.shape[0])
        for i in range(n_initial):
            buffer.add(init_buffer_losses[i].item(), init_buffer_ids[[i]])

        buffer.log_buffer(tokenizer)
        logger.info("Initialized attack buffer.")
        return buffer

    def compute_token_gradient(
        self,
        optim_ids: Tensor,
    ) -> Tensor:
        """Return the gradient of the summed per-prompt loss w.r.t. the one-hot token matrix.

        Args:
            optim_ids : Tensor, shape = (1, n_optim_ids)
                the sequence of token ids that are being optimized

        Returns:
            Tensor with the shape of the one-hot encoding of `optim_ids`,
            i.e. (1, n_optim_ids, vocab_size).
        """
        model = self.model
        embedding_layer = self.embedding_layer

        # Differentiable stand-in for the embedding lookup: one-hot @ embedding matrix.
        optim_ids_onehot = torch.nn.functional.one_hot(optim_ids, num_classes=embedding_layer.num_embeddings)
        optim_ids_onehot = optim_ids_onehot.to(model.device, model.dtype)
        optim_ids_onehot.requires_grad_()

        optim_embeds = optim_ids_onehot @ embedding_layer.weight

        total_loss = 0
        num_prompts = len(self.before_embeds_list)

        for i in range(num_prompts):
            # Same per-prompt decision as in `_compute_candidates_loss_original`.
            prefix_cache = self.prefix_cache_list[i] if self.prefix_cache_list else None

            if prefix_cache:
                input_embeds = torch.cat([optim_embeds, self.after_embeds_list[i], self.target_embeds], dim=1)
                output = model(
                    inputs_embeds=input_embeds,
                    past_key_values=prefix_cache,
                    use_cache=True,
                )
            else:
                input_embeds = torch.cat(
                    [
                        self.before_embeds_list[i],
                        optim_embeds,
                        self.after_embeds_list[i],
                        self.target_embeds,
                    ],
                    dim=1,
                )
                output = model(inputs_embeds=input_embeds)

            logits = output.logits

            # Logits that score the target tokens sit one position to their left.
            shift = input_embeds.shape[1] - self.target_ids.shape[1]
            shift_logits = logits[..., shift - 1 : -1, :].contiguous()
            shift_labels = self.target_ids

            if self.config.use_mellowmax:
                label_logits = torch.gather(shift_logits, -1, shift_labels.unsqueeze(-1)).squeeze(-1)
                loss = mellowmax(-label_logits, alpha=self.config.mellowmax_alpha, dim=-1)
            else:
                loss = torch.nn.functional.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

            total_loss += loss

        optim_ids_onehot_grad = torch.autograd.grad(outputs=[total_loss], inputs=[optim_ids_onehot])[0]

        return optim_ids_onehot_grad

    def _compute_candidates_loss_original(
        self,
        search_batch_size: int,
        sampled_ids: Tensor,
    ) -> Tensor:
        """Score every candidate sequence, averaged over all prompts.

        Args:
            search_batch_size : int
                number of candidates per forward pass; supplied by the
                `find_executable_batch_size` wrapper, not by the caller
            sampled_ids : Tensor, shape = (search_width, n_optim_ids)
                the candidate token id sequences to evaluate

        Returns:
            Tensor, shape = (search_width): mean loss of each candidate across prompts.

        Side effect: with `config.early_stop`, sets `self.stop_flag` once some
        candidate's argmax prediction matches every target token for a prompt.
        """
        all_loss_total = torch.zeros(sampled_ids.shape[0], device=self.model.device)
        num_prompts = len(self.before_embeds_list)

        with torch.no_grad():
            for i in range(num_prompts):
                before_embeds = self.before_embeds_list[i]
                after_embeds = self.after_embeds_list[i]
                prefix_cache = self.prefix_cache_list[i] if self.config.use_prefix_cache else None

                # Batch-expanded cache, rebuilt per prompt and whenever the batch size changes.
                prefix_cache_batch = None

                for j in range(0, sampled_ids.shape[0], search_batch_size):
                    batch_ids = sampled_ids[j:j + search_batch_size]
                    current_batch_size = batch_ids.shape[0]
                    optim_embeds_batch = self.embedding_layer(batch_ids)

                    if prefix_cache:
                        input_embeds_batch = torch.cat([
                            optim_embeds_batch,
                            after_embeds.repeat(current_batch_size, 1, 1),
                            self.target_embeds.repeat(current_batch_size, 1, 1),
                        ], dim=1)

                        if not prefix_cache_batch or current_batch_size != search_batch_size:
                            prefix_cache_batch = [[x.expand(current_batch_size, -1, -1, -1) for x in layer_cache] for layer_cache in prefix_cache]

                        outputs = self.model(inputs_embeds=input_embeds_batch, past_key_values=prefix_cache_batch, use_cache=True)
                    else:
                        input_embeds_batch = torch.cat([
                            before_embeds.repeat(current_batch_size, 1, 1),
                            optim_embeds_batch,
                            after_embeds.repeat(current_batch_size, 1, 1),
                            self.target_embeds.repeat(current_batch_size, 1, 1),
                        ], dim=1)
                        outputs = self.model(inputs_embeds=input_embeds_batch)

                    logits = outputs.logits

                    # Logits that score the target tokens sit one position to their left.
                    tmp = input_embeds_batch.shape[1] - self.target_ids.shape[1]
                    shift_logits = logits[..., tmp - 1:-1, :].contiguous()
                    shift_labels = self.target_ids.repeat(current_batch_size, 1)

                    if self.config.use_mellowmax:
                        label_logits = torch.gather(shift_logits, -1, shift_labels.unsqueeze(-1)).squeeze(-1)
                        loss = mellowmax(-label_logits, alpha=self.config.mellowmax_alpha, dim=-1)
                    else:
                        loss = torch.nn.functional.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1), reduction="none")
                        loss = loss.view(current_batch_size, -1).mean(dim=-1)

                    all_loss_total[j:j + search_batch_size] += loss

                    if self.config.early_stop and not self.stop_flag:
                        if torch.any(torch.all(torch.argmax(shift_logits, dim=-1) == shift_labels, dim=-1)).item():
                            self.stop_flag = True

        return all_loss_total / num_prompts
    

def run(
    model: transformers.PreTrainedModel,
    tokenizer: transformers.PreTrainedTokenizer,
    messages: Union[str, List[dict], List[Union[str, List[dict]]]],
    target: str,
    config: Optional[GCGConfig] = None,
) -> GCGResult:
    """Run GCG on one prompt or on several prompts sharing one optimized string.

    Args:
        model: the model to optimize against.
        tokenizer: the model's tokenizer.
        messages: one of
            - a single prompt string,
            - a single conversation (list of role/content dicts),
            - a list of prompts, each a string or a conversation.
        target: the text the model should produce.
        config: optimization settings; a default `GCGConfig` is used when None.

    Returns:
        The `GCGResult` produced by `GCG.run`.

    Raises:
        ValueError: if `messages` is an empty list.
    """
    if config is None:
        config = GCGConfig()

    logger.setLevel(getattr(logging, config.verbosity))

    if isinstance(messages, list) and not messages:
        raise ValueError("`messages` must not be empty.")

    gcg = GCG(model, tokenizer, config)

    # Several prompts are given only when the elements are themselves prompts
    # (strings or conversations). A list of dicts is ONE conversation and must
    # be wrapped, otherwise each message dict would be treated as a prompt.
    is_multi_prompt = isinstance(messages, list) and isinstance(messages[0], (list, str))
    if not is_multi_prompt:
        messages = [messages]

    return gcg.run(messages, target)