# SPDX-License-Identifier: MIT
from __future__ import annotations
from typing import Iterable, List, Optional, Tuple, Union
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from .model_utils import get_eos_token_ids, build_position_ids


def pad_or_truncate_seq(emb_seq: torch.Tensor, max_tokens: int, pad_vec: torch.Tensor | None = None) -> torch.Tensor:
    """Return ``emb_seq`` truncated or padded to exactly ``max_tokens`` rows.

    Args:
        emb_seq: 2-D tensor of shape ``(t, d)`` (one row per token).
        max_tokens: Target number of rows; must be non-negative.
        pad_vec: Optional vector of ``d`` elements used for every padding row.
            Defaults to zeros. It is cast to ``emb_seq``'s device and dtype so
            the concatenation neither fails on a device mismatch nor silently
            promotes the result dtype.

    Returns:
        A ``(max_tokens, d)`` tensor. When truncating (or when ``t`` already
        equals ``max_tokens``) this is a slice of ``emb_seq``. When padding it
        is a newly allocated tensor.

    Raises:
        ValueError: if ``max_tokens`` is negative. Without this check a negative
            slice bound would quietly drop rows from the end.
    """
    if max_tokens < 0:
        raise ValueError(f"max_tokens must be non-negative, got {max_tokens}")
    t, d = emb_seq.shape
    if t >= max_tokens:
        return emb_seq[:max_tokens]
    if pad_vec is None:
        pad_vec = torch.zeros(d, device=emb_seq.device, dtype=emb_seq.dtype)
    else:
        # Align with the sequence so torch.cat is valid and keeps emb_seq's dtype.
        pad_vec = pad_vec.to(device=emb_seq.device, dtype=emb_seq.dtype)
    # expand() builds a broadcast view; torch.cat performs the single real copy.
    pad = pad_vec.reshape(1, d).expand(max_tokens - t, d)
    return torch.cat([emb_seq, pad], dim=0)


@torch.no_grad()
def generate_with_embedding_tensor_hf(
    tokens: Union[List[int], torch.LongTensor],
    max_new_tokens: int = 1,
    model: AutoModelForCausalLM | None = None,
    tokenizer: AutoTokenizer | None = None,
    capture: str = "token",
    eos_token_id: int | None = None,
) -> Tuple[torch.Tensor, str]:
    """Capture the prompt's embeddings and greedily continue the prompt.

    Args:
        tokens: Prompt token ids, as a list or a 1-D / 2-D tensor. A 1-D
            tensor gets a batch dimension added. Only batch size 1 is supported
            because tokens are extracted with ``.item()``.
        max_new_tokens: Maximum number of tokens to generate. ``0`` yields ``""``.
        model: Causal LM (required).
        tokenizer: Tokenizer used for decoding and the default EOS id (required).
        capture: Which prompt representation to return:
            ``"token"``: output of ``model.get_input_embeddings()``.
            ``"after_pos"``: ``hidden_states[0]`` from the model's forward pass.
            Whether this includes positional information depends on the
            architecture.
        eos_token_id: Generation stops after this id is produced. Defaults to
            ``tokenizer.eos_token_id``, or ``2`` when the tokenizer has none.

    Returns:
        A tuple ``(embedding_tensor, new_text)``.
        ``embedding_tensor``: the first batch row of the captured
        representation, detached and moved to the CPU.
        ``new_text``: the decoded greedy continuation, with special tokens kept.

    Raises:
        AssertionError: if ``model`` or ``tokenizer`` is missing.
        TypeError: if ``tokens`` is neither a list nor a tensor.
        ValueError: if ``capture`` is not a recognised mode.
    """
    assert model is not None and tokenizer is not None
    device = next(model.parameters()).device
    if isinstance(tokens, list):
        input_ids = torch.tensor(tokens, dtype=torch.long, device=device).unsqueeze(0)
    elif isinstance(tokens, torch.Tensor):
        if tokens.dim() == 1:
            tokens = tokens.unsqueeze(0)
        input_ids = tokens.to(device=device, dtype=torch.long)
    else:
        raise TypeError("tokens should be List[int] or torch.LongTensor")

    if eos_token_id is None:
        eos_token_id = tokenizer.eos_token_id if tokenizer.eos_token_id is not None else 2

    # Prefill pass over the whole prompt; its last-position logits give the
    # first generated token.
    if capture == "token":
        embedding_tensor = model.get_input_embeddings()(input_ids).detach().cpu()[0]
        out = model(input_ids=input_ids, use_cache=True)
    elif capture == "after_pos":
        out = model(input_ids=input_ids, output_hidden_states=True, use_cache=True)
        embedding_tensor = out.hidden_states[0].detach().cpu()[0]
    else:
        raise ValueError("capture must be 'token' or 'after_pos'")

    generated_ids: List[int] = []
    for _ in range(max_new_tokens):
        next_token = torch.argmax(out.logits[:, -1, :], dim=-1)
        token_id = int(next_token.item())
        # Record the token predicted by the previous pass. The old loop fed it
        # back without recording it, which dropped the first generated token.
        generated_ids.append(token_id)
        if token_id == eos_token_id or len(generated_ids) == max_new_tokens:
            break  # no need for another forward pass
        out = model(input_ids=next_token.view(1, 1), past_key_values=out.past_key_values, use_cache=True)

    new_text = tokenizer.decode(generated_ids, skip_special_tokens=False) if generated_ids else ""
    return embedding_tensor, new_text


def word_token_embeddings_hf(
    word: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    capture_mode: str = "token",
) -> torch.Tensor:
    """Return one embedding row per token of ``word``.

    The word is tokenized without special tokens and placed on the model's
    device as a batch of one.

    Args:
        word: Text to embed.
        model: Causal LM providing the embeddings.
        tokenizer: Tokenizer used to split ``word`` into ids.
        capture_mode: ``"token"`` returns the input-embedding lookup.
            ``"resid0"`` returns ``hidden_states[0]`` from a full forward pass.

    Returns:
        The first (only) batch row of the selected representation. The
        function does not disable autograd, so the result may carry a graph.

    Raises:
        ValueError: if ``capture_mode`` is not recognised.
    """
    target_device = next(model.parameters()).device
    token_ids = tokenizer.encode(word, add_special_tokens=False)
    batch = torch.tensor(token_ids, dtype=torch.long, device=target_device).unsqueeze(0)

    if capture_mode == "resid0":
        hidden = model(input_ids=batch, output_hidden_states=True, use_cache=False).hidden_states
        return hidden[0][0]
    if capture_mode == "token":
        embed_layer = model.get_input_embeddings()
        return embed_layer(batch)[0]
    raise ValueError("capture_mode must be 'token' or 'resid0'")


@torch.no_grad()
def generate_with_new_embedding_hf(
    concatenated_tensor: torch.Tensor,
    input_sentence: str,
    model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    max_new_tokens: int = 100,
    k: int = 0,
    verbose: bool = False,
    eos_token_id: Optional[Union[int, Iterable[int]]] = None,
    add_special_tokens: bool = False,
) -> str:
    """Greedily generate text from caller-supplied input embeddings.

    ``input_sentence`` is tokenized only to check that ``concatenated_tensor``
    has one embedding per token. The token ids themselves are never fed to the
    model; the prompt is given as ``inputs_embeds``.

    Args:
        concatenated_tensor: Embeddings of shape ``(L, d_model)`` or
            ``(1, L, d_model)``. ``L`` must equal the token count of
            ``input_sentence``. The tensor is cast to the embedding weight's
            dtype and moved to the model's device.
        input_sentence: Sentence whose token count must match the embeddings.
        model: Causal LM. It is put into eval mode as a side effect.
        tokenizer: Tokenizer used for length checking and decoding.
        max_new_tokens: Maximum number of tokens to generate.
        k: Currently unused; kept for backward compatibility.
        verbose: Currently unused; kept for backward compatibility.
        eos_token_id: Optional override forwarded to ``get_eos_token_ids``.
            If that helper returns an empty collection, generation does not
            stop on EOS.
        add_special_tokens: Forwarded to ``tokenizer.encode`` for the length check.

    Returns:
        The decoded continuation, with special tokens kept.

    Raises:
        AssertionError: if the embedding shape does not match the batch size (1),
            the sentence length, or the model's hidden size.
    """
    device = next(model.parameters()).device
    model.eval()
    input_ids = tokenizer.encode(input_sentence, add_special_tokens=add_special_tokens, return_tensors="pt").to(device)
    bsz, seq_len = input_ids.shape
    if concatenated_tensor.dim() == 2:
        concatenated_tensor = concatenated_tensor.unsqueeze(0)
    assert concatenated_tensor.shape[0] == 1, "only batch size 1 is supported"
    assert concatenated_tensor.shape[1] == seq_len, (
        f"embedding length {concatenated_tensor.shape[1]} != tokenized sentence length {seq_len}"
    )
    d_model = model.get_input_embeddings().weight.shape[1]
    assert concatenated_tensor.shape[2] == d_model, (
        f"embedding width {concatenated_tensor.shape[2]} != model hidden size {d_model}"
    )

    embed_dtype = model.get_input_embeddings().weight.dtype
    inputs_embeds = concatenated_tensor.to(device=device, dtype=embed_dtype)
    attention_mask = torch.ones(bsz, seq_len, dtype=torch.long, device=device)
    position_ids = build_position_ids(attention_mask)

    # Prefill: run the supplied embeddings once to populate the KV cache.
    out = model(inputs_embeds=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, use_cache=True)
    past = out.past_key_values if hasattr(out, "past_key_values") else out["past_key_values"]
    logits = out.logits[:, -1, :]
    generated_ids: list[int] = []
    eos_set = get_eos_token_ids(tokenizer, model, eos_override=eos_token_id)

    for _ in range(max_new_tokens):
        next_token = torch.argmax(logits, dim=-1)
        token_id = int(next_token.item())
        generated_ids.append(token_id)
        if eos_set and token_id in eos_set:
            break
        if len(generated_ids) >= max_new_tokens:
            break  # budget exhausted; skip a forward pass whose logits would be discarded
        # The mask must cover every cached position plus the new token. Previously
        # it was extended here but never passed to the model.
        attention_mask = torch.cat([attention_mask, attention_mask.new_ones(bsz, 1)], dim=1)
        out = model(
            input_ids=next_token.view(1, 1),
            attention_mask=attention_mask,
            past_key_values=past,
            use_cache=True,
        )
        past = out.past_key_values if hasattr(out, "past_key_values") else out["past_key_values"]
        logits = out.logits[:, -1, :]
    return tokenizer.decode(generated_ids, skip_special_tokens=False)