import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def load_model_and_tokenizer(
    model_path, tokenizer_path=None, device="cuda:0", **kwargs
):
    """
    Load a causal language model in float16 together with its tokenizer.

    Parameters
    ----------
    model_path
        Forwarded to ``AutoModelForCausalLM.from_pretrained``.
    tokenizer_path : optional
        Forwarded to ``AutoTokenizer.from_pretrained``. Falls back to
        ``model_path`` when None.
    device
        Device the loaded model is moved to.
    **kwargs
        Extra keyword arguments forwarded to the model's ``from_pretrained``.

    Returns
    -------
    tuple
        ``(model, tokenizer)``; the model is on ``device`` and in eval mode.
    """
    if tokenizer_path is None:
        tokenizer_path = model_path

    # NOTE: trust_remote_code=True lets the checkpoint run its own Python code.
    model = (
        AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch.float16,
            trust_remote_code=True,
            **kwargs,
        )
        .to(device)
        .eval()
    )

    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_path,
        trust_remote_code=True,
        use_fast=False,
    )

    # Tokenizers without a pad token reuse the end-of-sequence token for padding.
    if not tokenizer.pad_token:
        tokenizer.pad_token = tokenizer.eos_token

    return model, tokenizer


class SuffixManager:
    """Build tokenized prompts from a conversation template and locate the suffix/target tokens."""

    def __init__(self, *, tokenizer, use_system_instructions, conv_template):
        self.tokenizer = tokenizer
        self.use_system_instructions = use_system_instructions
        self.conv_template = conv_template

    def _prompt_token_ids(self):
        """Tokenize the template's current prompt and return the list of token ids."""
        return self.tokenizer(self.conv_template.get_prompt()).input_ids

    def get_input_ids(self, messages, adv_suffix, target, static_only=False):
        """
        Tokenize the conversation plus suffix and target, and report where each part sits.

        The conversation template held by this object is reset and then
        mutated step by step; the prompt is re-tokenized after every step so
        that token boundaries can be read off the growing token list.

        Parameters
        ----------
        messages : list of Message objects
            ``messages[0].content`` is installed as the system message when it
            is truthy. With ``use_system_instructions`` set, ``messages[1]``
            is the user message. Otherwise ``messages[1]`` (user rules) and
            ``messages[2]`` (assistant response) are added as earlier turns
            and ``messages[3]`` is the user message.
        adv_suffix : str
            Current adversarial suffix, appended to the user message after a
            single space.
        target : str
            Current target output for the model.
        static_only : bool
            When True, return only the token ids that precede the suffix.

        Returns
        -------
        torch.Tensor
            If ``static_only``: ids of the prompt up to the end of the user
            message, without its trailing separator tokens.
        tuple
            Otherwise ``(input_ids, optim_slice, target_slice, loss_slice)``:
            the ids up to the end of the target, the slice covering the
            suffix, the slice covering the target, and the target slice
            shifted one position to the left.

        Raises
        ------
        NotImplementedError
            If the template name is not one of the accepted templates.
        """
        template = self.conv_template

        # Only these templates are accepted; remove this check to experiment with others.
        if template.name not in ("llama-2", "vicuna_v1.1", "mistral"):
            raise NotImplementedError

        # Token counts of the two separators, needed to trim them off prompt ends.
        sep_len = len(
            self.tokenizer(template.sep, add_special_tokens=False).input_ids
        )
        sep2_len = len(
            self.tokenizer(template.sep2, add_special_tokens=False).input_ids
        )

        # Start from an empty conversation on the shared template.
        template.messages = []
        roles = template.roles

        system_text = messages[0].content
        if system_text:
            template.set_system_message(system_text)

        if self.use_system_instructions:
            user_msg = messages[1].content
        else:
            template.append_message(roles[0], messages[1].content)  # user rules
            template.append_message(roles[1], messages[2].content)  # asst response
            user_msg = messages[3].content

        # Step 1: prompt ending with the plain user message.
        template.append_message(roles[0], user_msg)
        prompt_ids = self._prompt_token_ids()
        static_input_ids = torch.tensor(prompt_ids[: len(prompt_ids) - sep_len])
        if static_only:
            return static_input_ids
        static_len = len(static_input_ids)

        # Step 2: user message followed by the adversarial suffix.
        template.update_last_message(f"{user_msg} {adv_suffix}")
        suffix_end = len(self._prompt_token_ids()) - sep_len
        optim_slice = slice(static_len, suffix_end)

        # Step 3: open the assistant turn; the target starts where this prompt ends.
        template.append_message(roles[1], None)
        target_start = len(self._prompt_token_ids())

        # Step 4: fill the assistant turn with the target.
        template.update_last_message(target)
        prompt_ids = self._prompt_token_ids()
        target_end = len(prompt_ids) - sep2_len
        target_slice = slice(target_start, target_end)
        loss_slice = slice(target_start - 1, target_end - 1)

        # The final separator tokens are not needed.
        input_ids = torch.tensor(prompt_ids[:target_end])

        return input_ids, optim_slice, target_slice, loss_slice


def sample_updates(token_ids, grad, batch_size, topk=256, not_allowed_tokens=None):
    """
    Sample a batch of candidate sequences that each differ from `token_ids` in one position.

    Row ``i`` of the result replaces the token at position
    ``floor(i * len(token_ids) / batch_size)`` with a token drawn uniformly
    from the ``topk`` lowest-gradient tokens at that position.

    Parameters
    ----------
    token_ids : torch.Tensor
        Current token ids; expected to be 1-D (its ``len`` is used as the
        number of positions).
    grad : torch.Tensor
        Indexed as ``grad[position, token]``; lower values are preferred.
        NOTE: modified in place when `not_allowed_tokens` is given.
    batch_size : int
        Number of candidate sequences to produce.
    topk : int
        Number of lowest-gradient tokens to sample from at each position.
    not_allowed_tokens : torch.Tensor, optional
        Token ids whose gradient is set to +inf so they rank last.

    Returns
    -------
    torch.Tensor
        Tensor with ``batch_size`` rows, each a copy of `token_ids` with one
        position replaced.
    """
    if not_allowed_tokens is not None:
        # float("inf") rather than np.infty: that alias was removed in NumPy 2.0.
        grad[:, not_allowed_tokens.to(grad.device)] = float("inf")

    # Negate so that topk picks the tokens with the smallest gradient values.
    top_indices = (-grad).topk(topk, dim=1).indices

    num_positions = len(token_ids)
    old_token_ids = token_ids.repeat(batch_size, 1)

    # Spread the batch evenly over positions using exact integer arithmetic.
    # A float step in torch.arange can round to the wrong element count or
    # truncate to the wrong index.
    new_token_pos = torch.div(
        torch.arange(batch_size, device=grad.device) * num_positions,
        batch_size,
        rounding_mode="floor",
    )

    # For each row pick one of the topk candidates at its position.
    new_token_val = torch.gather(
        top_indices[new_token_pos],
        1,
        torch.randint(0, topk, (batch_size, 1), device=grad.device),
    )
    new_token_ids = old_token_ids.scatter(1, new_token_pos.unsqueeze(-1), new_token_val)

    return new_token_ids


def batchify_kv_cache(prefix_cache, batch_size):
    """
    Tile every key/value tensor of a cache `batch_size` times along its first dimension.

    Parameters
    ----------
    prefix_cache : iterable of (key, value) tensor pairs
        One pair per entry; presumably one per layer with a leading batch
        dimension of 1 (confirm against the caller).
    batch_size : int
        Repeat count for the first dimension; the other three are kept.

    Returns
    -------
    tuple
        Tuple of ``(key, value)`` tuples holding the repeated tensors.
    """
    tiling = (batch_size, 1, 1, 1)
    return tuple(
        (key.repeat(*tiling), value.repeat(*tiling))
        for key, value in prefix_cache
    )


def get_nonascii_toks(tokenizer, device="cpu"):
    """
    Collect the ids of tokens that are not printable ASCII, plus the special token ids.

    Ids ``3 .. tokenizer.vocab_size - 1`` are decoded one at a time; an id is
    kept when its decoded text is not both ASCII and printable. Ids 0-2 are
    not scanned (presumably reserved for special tokens). The BOS, EOS, PAD
    and UNK ids are then appended, in that order, whenever they are not None.

    Parameters
    ----------
    tokenizer
        Object providing ``vocab_size``, ``decode`` and the four
        ``*_token_id`` attributes.
    device
        Device on which the result tensor is created.

    Returns
    -------
    torch.Tensor
        Tensor of the collected token ids (duplicates are not removed).
    """
    blocked_ids = []
    for token_id in range(3, tokenizer.vocab_size):
        text = tokenizer.decode([token_id])
        if not (text.isascii() and text.isprintable()):
            blocked_ids.append(token_id)

    special_ids = (
        tokenizer.bos_token_id,
        tokenizer.eos_token_id,
        tokenizer.pad_token_id,
        tokenizer.unk_token_id,
    )
    blocked_ids.extend(
        special_id for special_id in special_ids if special_id is not None
    )

    return torch.tensor(blocked_ids, device=device)
