# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utility class and functions.

Adapted from:
https://github.com/kmeng01/rome/blob/bef95a6afd2ca15d794bdd4e3ee0f24283f9b996/
"""

import re

import torch
import transformers


class ModelAndTokenizer:
    """Container pairing a causal language model with its tokenizer.

    Whatever is not supplied explicitly is loaded from ``model_name`` via
    ``transformers.AutoTokenizer`` / ``transformers.AutoModelForCausalLM``.
    The model is then moved to ``device`` (skipped when ``device`` is None),
    frozen with ``set_requires_grad(False, ...)`` and switched to eval mode.

    Attributes:
      tokenizer: the tokenizer; when it is loaded here and has no pad token,
        its EOS token is reused as the pad token.
      model: the language model.
      device: the ``device`` argument exactly as given.
      layer_names: names from ``model.named_modules()`` that fully match
        ``(transformer|gpt_neox|model).(h|layers).<digits>``, in that order.
      num_layers: ``len(layer_names)``.
    """

    def __init__(
            self,
            model_name=None,
            model=None,
            tokenizer=None,
            low_cpu_mem_usage=False,
            torch_dtype=None,
            use_fast=True,
            device="cuda",
            attn_implementation="flash_attention_2",
    ):
        # Load any missing component; both load paths need a model name.
        if tokenizer is None:
            assert model_name is not None
            tokenizer = transformers.AutoTokenizer.from_pretrained(
                model_name, use_fast=use_fast
            )
            if tokenizer.pad_token is None:
                # Reuse EOS as the pad token when none is defined.
                tokenizer.pad_token = tokenizer.eos_token
        if model is None:
            assert model_name is not None
            model = transformers.AutoModelForCausalLM.from_pretrained(
                model_name,
                low_cpu_mem_usage=low_cpu_mem_usage,
                torch_dtype=torch_dtype,
                attn_implementation=attn_implementation,
            )

        # Inference-only setup: place on device, freeze weights, eval mode.
        if device is not None:
            model.to(device)
        set_requires_grad(False, model)
        model.eval()

        self.tokenizer = tokenizer
        self.model = model
        self.device = device

        # Top-level transformer blocks, e.g. "model.layers.3" or
        # "transformer.h.0"; sub-modules of a block do not match ($ anchor).
        block_pattern = re.compile(
            r"^(transformer|gpt_neox|model)\.(h|layers)\.\d+$"
        )
        self.layer_names = [
            name
            for name, _ in model.named_modules()
            if block_pattern.match(name)
        ]
        self.num_layers = len(self.layer_names)

    def __repr__(self):
        """Summarize the model class, its layer count and the tokenizer class."""
        model_cls = type(self.model).__name__
        tokenizer_cls = type(self.tokenizer).__name__
        return (
            f"ModelAndTokenizer(model: {model_cls} "
            f"[{self.num_layers} layers], "
            f"tokenizer: {tokenizer_cls})"
        )


def make_inputs(tokenizer, prompts, device="cuda"):
    """Tokenize ``prompts`` and left-pad them into one batch.

    Each prompt is encoded with ``tokenizer.encode``. Shorter sequences are
    padded on the LEFT up to the longest one. The pad id is the id of the
    "[PAD]" special token if the tokenizer lists one, otherwise 0. Padded
    positions get attention-mask 0 and real tokens get 1.

    Args:
      tokenizer: object exposing ``encode``, ``all_special_tokens`` and
        ``all_special_ids``.
      prompts: iterable of prompt strings; ``max`` raises ValueError if empty.
      device: target device for the returned tensors.

    Returns:
      dict with ``input_ids`` and ``attention_mask`` tensors, each with one
      row per prompt and as many columns as the longest encoding.
    """
    encoded = [tokenizer.encode(prompt) for prompt in prompts]
    width = max(map(len, encoded))

    special_tokens = tokenizer.all_special_tokens
    pad_id = (
        tokenizer.all_special_ids[special_tokens.index("[PAD]")]
        if "[PAD]" in special_tokens
        else 0
    )

    id_rows, mask_rows = [], []
    for ids in encoded:
        gap = width - len(ids)
        id_rows.append([pad_id] * gap + ids)
        mask_rows.append([0] * gap + [1] * len(ids))

    return {
        "input_ids": torch.tensor(id_rows).to(device),
        "attention_mask": torch.tensor(mask_rows).to(device),
    }


def decode_tokens(tokenizer, token_array):
    """Decode every token id on its own into a string.

    A flat sequence becomes a list of strings, one per token. Anything with a
    ``.shape`` of rank greater than 1 (e.g. a 2-D tensor) is handled row by
    row, giving a list of such lists.
    """
    if not hasattr(token_array, "shape") or len(token_array.shape) <= 1:
        return [tokenizer.decode([token_id]) for token_id in token_array]
    return [decode_tokens(tokenizer, row) for row in token_array]


def find_token_range(tokenizer, token_array, substring):
    """Find the tokens corresponding to the given substring in token_array.

    Tokens are decoded one by one with ``decode_tokens`` and concatenated.
    The FIRST occurrence of ``substring`` in that concatenation is mapped back
    to a half-open token span ``(start, end)``: ``start`` is the first token
    whose text reaches past the substring's first character, and ``end`` is
    one past the token that completes it.

    Raises:
      ValueError: from ``str.index`` when ``substring`` does not occur.

    NOTE(review): assumes per-token decodes concatenate to the full text;
    tokenizers that split characters across tokens may break this — confirm.
    """
    pieces = decode_tokens(tokenizer, token_array)
    start_char = "".join(pieces).index(substring)
    end_char = start_char + len(substring)

    first_tok, last_tok = None, None
    consumed = 0  # characters covered by pieces[0..idx]
    for idx, piece in enumerate(pieces):
        consumed += len(piece)
        if first_tok is None and consumed > start_char:
            first_tok = idx
        if consumed >= end_char:
            last_tok = idx + 1
            break
    return (first_tok, last_tok)


def predict_from_input(model, inp):
    """Return the most likely next token and its probability.

    Runs ``model(**inp)`` and takes the ``"logits"`` entry at the last
    position (``[:, -1]``). A softmax is applied over dim 1 of that slice.

    Returns:
      (preds, p): the argmax index along dim 1 for each row, and the
      corresponding softmax probability.
    """
    last_logits = model(**inp)["logits"][:, -1]
    top_prob, top_id = last_logits.softmax(dim=1).max(dim=1)
    return top_id, top_prob


def set_requires_grad(requires_grad, *models):
    """Set ``requires_grad`` in place on modules and tensors.

    Args:
      requires_grad: value assigned to ``.requires_grad``.
      *models: any mix of ``torch.nn.Module`` objects, whose
        ``parameters()`` are all updated, and ``torch.Tensor`` objects
        (including ``torch.nn.Parameter``), which are updated directly.

    Raises:
      AssertionError: if an argument is neither a Module nor a Tensor.
        Arguments before the offending one have already been updated.
    """
    for model in models:
        if isinstance(model, torch.nn.Module):
            for param in model.parameters():
                param.requires_grad = requires_grad
        elif isinstance(model, torch.Tensor):  # nn.Parameter subclasses Tensor
            model.requires_grad = requires_grad
        else:
            # Raised explicitly instead of `assert False` so the check still
            # fires under `python -O`. AssertionError is kept so existing
            # callers see the same exception type.
            raise AssertionError("unknown type %r" % type(model))


def find_last_continuous_position(source, target):
    """Locate the last contiguous occurrence of ``target`` in ``source``.

    Slices of ``source`` are compared to ``target`` with ``==``, so both
    should be the same kind of sequence. For example, a list slice never
    equals a tuple.

    Returns:
      The index in ``source`` of the LAST element of the right-most match,
      or -1 if ``target`` never occurs. An empty ``target`` matches at
      ``len(source)``, which yields ``len(source) - 1``.
    """
    target_len = len(target)
    # Scan candidate start positions right-to-left; the first hit is the
    # last occurrence, so we can stop immediately.
    for start in range(len(source) - target_len, -1, -1):
        if source[start:start + target_len] == target:
            return start + target_len - 1
    return -1


from transformers import LogitsProcessor

class FirstTokenBanProcessor(LogitsProcessor):
    """Logits processor that bans given token ids on the first generation step.

    The first step is detected when the running ``input_ids`` still has
    exactly ``initial_length`` columns. Presumably the caller passes the
    prompt length here; TODO confirm at call sites. On that step the
    ``banned_token_ids`` columns of ``scores`` are set to -inf. All other
    steps pass through unchanged.
    """

    def __init__(self, banned_token_ids, initial_length):
        self.initial_length = initial_length
        self.banned_token_ids = banned_token_ids

    def __call__(self, input_ids, scores):
        at_first_step = input_ids.shape[1] == self.initial_length
        if at_first_step:
            # Modifies `scores` in place; the same tensor is returned.
            scores[:, self.banned_token_ids] = -float("inf")
        return scores

