from typing import Tuple

import torch
import torch.nn.functional as F

def terminal_safe_repr(token: str) -> str:
    """
    Render `token` using only printable ASCII characters.

    Printable ASCII (codes 32-126) is passed through unchanged. A handful of
    control characters are written as their conventional backslash escapes
    (e.g. a newline becomes the two characters ``\\n``). Any other ASCII
    control character, and DEL, becomes ``\\xNN``. Every non-ASCII character
    is written as the ``\\xNN`` escapes of its UTF-8 bytes.

    Args:
        token (str): The text to escape.
    Returns:
        str: The escaped text.
    """
    # Control characters that have a well-known short escape.
    named_escapes = {
        '\n': '\\n',
        '\r': '\\r',
        '\t': '\\t',
        '\b': '\\b',
        '\f': '\\f',
        '\v': '\\v',
        '\a': '\\a',
        '\0': '\\0',
    }

    def escape_char(char: str) -> str:
        point = ord(char)
        if 32 <= point <= 126:
            # Printable ASCII is kept verbatim.
            return char
        if char in named_escapes:
            return named_escapes[char]
        if point < 128:
            # Remaining ASCII control characters and DEL.
            return f'\\x{point:02x}'
        # Non-ASCII: one hex escape per UTF-8 byte.
        return ''.join(f'\\x{byte:02x}' for byte in char.encode('utf-8'))

    return ''.join(map(escape_char, token))


def format_token(token: str, length: int = 15) -> str:
    """
    Escape `token` for terminal display and cap it at `length` characters.

    The token is first passed through `terminal_safe_repr`. If the escaped
    text is longer than `length`, it is cut and terminated with ``...`` so
    that the result is exactly `length` characters long. When `length` is too
    small to hold the ellipsis (less than 3), the escaped text is simply
    truncated to `length` characters (negative values are treated as 0).

    NOTE: truncation operates on the escaped text, so the cut may fall in the
    middle of an escape sequence such as ``\\x0a``.

    Args:
        token (str): The raw token text.
        length (int): Maximum number of characters in the result.
    Returns:
        str: The escaped, possibly truncated token.
    """
    marker = '...'
    escaped = terminal_safe_repr(token)
    if len(escaped) <= length:
        return escaped
    if length < len(marker):
        # A negative slice bound here would keep most of the string instead
        # of shortening it, so truncate without the ellipsis.
        return escaped[:max(length, 0)]
    return escaped[:length - len(marker)] + marker


def extract_hidden_states_ids(
    input_ids: torch.LongTensor,
    model: torch.nn.Module,
    layer_idx: int,
    grad: bool = False,
    batch: bool = False
) -> torch.Tensor:
    """
    Run `model` on `input_ids` and return the hidden states of one layer.

    The model is called with `output_hidden_states=True` and `use_cache=False`
    and must return an object exposing `hidden_states`, indexable by layer.

    Args:
        input_ids (torch.LongTensor): Token IDs for the sequence(s);
            presumably of shape (batch, seq_len) - confirm against callers.
        model (torch.nn.Module): The language model.
        layer_idx (int): Index into `outputs.hidden_states`.
        grad (bool): If False, the forward pass runs under `torch.no_grad()`
            and the result is detached. If True, the forward pass runs
            normally so gradients can flow through the result.
        batch (bool): If True, return `hidden_states[layer_idx]` unchanged.
            If False, return only its first entry along dimension 0.
    Returns:
        torch.Tensor: The selected layer's hidden states for every position
        (not only the last token). Both `grad` settings return the same
        selection; they differ only in whether the result tracks gradients.
    """
    def _forward_and_select() -> torch.Tensor:
        outputs = model(
            input_ids=input_ids,
            output_hidden_states=True,
            use_cache=False
        )
        layer_states = outputs.hidden_states[layer_idx]
        # Index dimension 0 exactly once; indexing twice would silently
        # return a single position instead of the whole sequence.
        return layer_states if batch else layer_states[0]

    if grad:
        return _forward_and_select()

    with torch.no_grad():
        selected = _forward_and_select()
    return selected.detach()
    
def extract_all_hidden_states_ids(
    input_ids: torch.LongTensor,
    model: torch.nn.Module,
    grad: bool = False,
    batch: bool = False
) -> "list[torch.Tensor] | tuple[torch.Tensor, ...]":
    """
    Run `model` on `input_ids` and return the hidden states of every layer.

    The model is called with `output_hidden_states=True` and must return an
    object whose `hidden_states` is an iterable of per-layer tensors (the
    tuple produced by HuggingFace-style models).

    Args:
        input_ids (torch.LongTensor): Token IDs for the sequence(s);
            presumably of shape (batch, seq_len) - confirm against callers.
        model (torch.nn.Module): The language model.
        grad (bool): If False, the forward pass runs under `torch.no_grad()`
            and every returned tensor is detached. If True, the forward pass
            runs normally so gradients can flow through the results.
        batch (bool): If True, per-layer tensors are returned with their
            leading dimension intact. If False, `squeeze(0)` is applied to
            each per-layer tensor, which removes the leading dimension only
            when it has size 1.
    Returns:
        A sequence with one tensor per entry of `outputs.hidden_states`:
        a tuple when `batch` is True, a list when `batch` is False.
    """
    if grad:
        # Forward normally, so gradients can flow
        per_layer = model(
            input_ids=input_ids,
            output_hidden_states=True
        ).hidden_states
        if batch:
            return per_layer
        return [state.squeeze(0) for state in per_layer]

    # Forward under no_grad and detach before returning
    with torch.no_grad():
        per_layer = model(
            input_ids=input_ids,
            output_hidden_states=True
        ).hidden_states

    # `hidden_states` is a container of tensors, so detach layer by layer
    # rather than calling tensor methods on the container itself.
    detached = [state.detach() for state in per_layer]
    if batch:
        return tuple(detached)
    return [state.squeeze(0) for state in detached]
    

def extract_hidden_states_prompt(
    prompt: str,
    model: torch.nn.Module,
    tokenizer,
    layer_idx: int,
    grad: bool = False,
    add_special_tokens: bool = False
) -> torch.Tensor:
    """
    Tokenize `prompt` and return hidden states of layer `layer_idx` for it.

    The prompt is tokenized with truncation to at most
    `min(tokenizer.model_max_length, 2048)` tokens, the resulting IDs are
    moved to the device of the model's first parameter, and the work is then
    delegated to `extract_hidden_states_ids` with its default `batch=False`.

    Args:
        prompt (str): the input string, e.g. "Harry".
        model (nn.Module): the model to run; must have at least one parameter.
        tokenizer: callable tokenizer returning a mapping with "input_ids"
            when invoked with `return_tensors="pt"`.
        layer_idx (int): which entry of the model's hidden states to extract.
        grad (bool): forwarded to `extract_hidden_states_ids`; controls
            whether gradients are tracked.
        add_special_tokens (bool): forwarded to the tokenizer.
    Returns:
        Whatever `extract_hidden_states_ids` returns for these token IDs.
    """
    target_device = next(model.parameters()).device
    token_limit = min(tokenizer.model_max_length, 2048)

    tokenized = tokenizer(
        prompt,
        return_tensors="pt",
        add_special_tokens=add_special_tokens,
        truncation=True,
        max_length=token_limit,
    )
    # presumably shape (1, seq_len) for a single prompt - TODO confirm
    prompt_ids = tokenized["input_ids"].to(target_device)

    return extract_hidden_states_ids(
        input_ids=prompt_ids,
        model=model,
        layer_idx=layer_idx,
        grad=grad,
    )


def extract_hidden_states(
    embeddings: torch.Tensor,
    model: torch.nn.Module,
    layer_idx: int,
    grad: bool = False
) -> torch.Tensor:
    """
    Run `model` on precomputed input embeddings and return the hidden states
    of layer `layer_idx` for the first batch entry.

    Args:
        embeddings (torch.Tensor): passed to the model as `inputs_embeds`;
            presumably of shape (1, seq_len, hidden_size) - confirm.
        model (nn.Module): a model accepting `inputs_embeds` and
            `output_hidden_states`, returning an object with `hidden_states`.
        layer_idx (int): which entry of `outputs.hidden_states` to extract.
        grad (bool): if False, run under torch.no_grad() and detach the
            result; if True, run normally so gradients can flow.

    Returns:
        `outputs.hidden_states[layer_idx][0]`, i.e. every position of the
        first batch entry (presumably shape (seq_len, hidden_size)).
    """
    def run_forward() -> torch.Tensor:
        model_out = model(
            inputs_embeds=embeddings,
            output_hidden_states=True
        )
        return model_out.hidden_states[layer_idx][0]

    if grad:
        # Forward normally, so gradients can flow
        return run_forward()

    # Forward under no_grad and detach before returning
    with torch.no_grad():
        first_entry = run_forward()
    return first_entry.detach()


def extract_hidden_states_iterative(
    input_ids: torch.Tensor,
    model: torch.nn.Module,
    layer_idx: int,
):
    """
    Compute, for every prefix of the sequence, the hidden state of that
    prefix's final position at layer `layer_idx`.

    Token embeddings are looked up directly in the weight of
    `model.get_input_embeddings()`. The model is then run once per prefix
    length (1, 2, ..., seq_len) with `use_cache=False`, all under
    `torch.no_grad()`.

    Args:
        input_ids (torch.Tensor): token IDs; `squeeze(0)` is applied before
            the lookup, so presumably shape (1, seq_len) - confirm.
        model (nn.Module): model exposing `get_input_embeddings()` and
            accepting `inputs_embeds`.
        layer_idx (int): which entry of `outputs.hidden_states` to read.
    Returns:
        The per-prefix final-position hidden states stacked along a new
        leading dimension, in order of increasing prefix length.
    """
    embedding_table = model.get_input_embeddings().weight
    token_embeds = embedding_table[input_ids.squeeze(0)].detach().unsqueeze(0)
    num_positions = token_embeds.size(1)

    with torch.no_grad():
        prefix_states = [
            model(
                inputs_embeds=token_embeds[:, :end, :],
                output_hidden_states=True,
                use_cache=False,
            ).hidden_states[layer_idx][0, -1, :].detach()
            for end in range(1, num_positions + 1)
        ]

    return torch.stack(prefix_states)


def compute_last_token_embedding_grad_emb(
    embeddings: torch.Tensor,
    model: torch.nn.Module,
    layer_idx: int,
    h_target: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Given precomputed token embeddings, run a forward pass, compute the
    summed squared error between the last position's hidden state at
    `layer_idx` and `h_target`, and return the gradient of that loss w.r.t.
    the last embedding together with the loss.

    Only the last embedding is a gradient-enabled leaf; the preceding
    embeddings are detached copies and are treated as constants.

    NOTE: the gradient is obtained with `loss.backward()`, which also
    accumulates into the `.grad` of any model parameter that requires grad.

    Args:
        embeddings:  Input embeddings, sliced along dimension 1; presumably
                     of shape (1, seq_len, hidden_size) - confirm.
        model:       A model accepting `inputs_embeds` and
                     `output_hidden_states`; must have at least one parameter
                     (its device is used for the inputs).
        layer_idx:   Index into `outputs.hidden_states`.
        h_target:    Target for `hidden_states[layer_idx][0, -1, :]`.

    Returns:
        grad_last_embedding: `last_emb.grad` with dimensions 0 and 1 squeezed.
        loss:                The loss as a 0-dim tensor (not a Python float);
                             call `.item()` on it if a float is needed.
    """
    # Move to device
    device = next(model.parameters()).device
    embeddings = embeddings.to(device)
    h_target = h_target.to(device)

    # Detached copy for the fixed prefix, fresh leaf for the last position
    fixed_embs = embeddings.clone().detach()
    last_emb = fixed_embs[:, -1:, :].clone().requires_grad_(True)

    # Reassemble inputs_embeds with fixed prefixes and grad-enabled last
    inputs_embeds = torch.cat([fixed_embs[:, :-1, :], last_emb], dim=1)

    # Forward pass from custom embeddings
    outputs = model(
        inputs_embeds=inputs_embeds,
        output_hidden_states=True
    )
    hidden_states = outputs.hidden_states
    h_last = hidden_states[layer_idx][0, -1, :]

    # Summed squared error for the last token
    loss = torch.nn.functional.mse_loss(h_last, h_target, reduction='sum')
    loss.backward()
    return last_emb.grad.squeeze(0, 1), loss


def compute_last_token_embedding_grad_emb_(
    embeddings: torch.Tensor,
    model: torch.nn.Module,
    layer_idx: int,
    h_target: torch.Tensor,
) -> Tuple[torch.Tensor, float]:
    """
    Dtype-aware gradient of a last-position loss w.r.t. the last embedding.

    The embeddings are cast to the dtype of the model's first parameter for
    the forward pass. The loss is the summed squared error, computed in
    float32, between `hidden_states[layer_idx][0, -1, :]` and `h_target`.
    Only the last embedding is a gradient-enabled leaf; the prefix is a
    detached copy.

    NOTE: `loss.backward()` is used, so gradients are also accumulated into
    any model parameter that requires grad.

    Args:
        embeddings:  Input embeddings, sliced along dimension 1; presumably
                     of shape (1, seq_len, hidden_size) - confirm.
        model:       A model accepting `inputs_embeds` and
                     `output_hidden_states`; must have at least one parameter.
        layer_idx:   Index into `outputs.hidden_states`.
        h_target:    Target for the last position's hidden state.

    Returns:
        grad_last_embedding: Gradient w.r.t. the last embedding, with
                             dimensions 0 and 1 squeezed, cast back to the
                             dtype of the `embeddings` argument.
        loss:                The loss as returned by `mse_loss` (a 0-dim
                             float32 tensor; `.item()` is not applied).
    """
    reference_param = next(model.parameters())
    model_device = reference_param.device
    model_dtype = reference_param.dtype

    embeddings = embeddings.to(model_device)
    h_target = h_target.to(model_device)

    # Constant prefix in the model's dtype; the last position becomes a leaf.
    frozen = embeddings.to(model_dtype).clone().detach()
    prefix, trainable_last = frozen[:, :-1, :], frozen[:, -1:, :].clone()
    trainable_last.requires_grad_(True)

    model_out = model(
        inputs_embeds=torch.cat([prefix, trainable_last], dim=1),
        output_hidden_states=True
    )
    last_hidden = model_out.hidden_states[layer_idx][0, -1, :]

    # Accumulate the loss in float32 regardless of the model's dtype.
    loss = F.mse_loss(
        last_hidden.to(torch.float32),
        h_target.to(torch.float32),
        reduction='sum',
    )
    loss.backward()

    grad_last = trainable_last.grad.squeeze(0, 1)
    return grad_last.to(embeddings.dtype), loss


def compute_all_token_embedding_grad_emb(
    embeddings: torch.Tensor,
    model: torch.nn.Module,
    layer_idx: int,
    h_target: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Gradient of an all-position hidden-state loss w.r.t. every embedding.

    A leading dimension is added to `embeddings` and the result is cast to
    the dtype of the model's first parameter before the forward pass. The
    loss is the summed squared error, computed in float32, between
    `hidden_states[layer_idx][0, :, :]` (all positions) and `h_target`.

    NOTE: `loss.backward()` is used, so gradients are also accumulated into
    any model parameter that requires grad.

    Args:
        embeddings:  Input embeddings without a batch dimension; presumably
                     of shape (seq_len, hidden_size) - confirm.
        model:       A model accepting `inputs_embeds` and
                     `output_hidden_states`; must have at least one parameter.
        layer_idx:   Index into `outputs.hidden_states`.
        h_target:    Target for the selected layer's hidden states across
                     all positions.

    Returns:
        grad_embeddings: Gradient w.r.t. the inputs, with the added leading
                         dimension squeezed away, cast back to the dtype of
                         the `embeddings` argument.
        loss:            The loss as a 0-dim float32 tensor.
    """
    reference_param = next(model.parameters())
    model_device, model_dtype = reference_param.device, reference_param.dtype

    embeddings = embeddings.to(model_device)
    h_target = h_target.to(model_device)

    # Fresh leaf tensor in the model's dtype, with a leading batch dimension.
    leaf_inputs = embeddings.to(model_dtype).clone().detach().unsqueeze(0)
    leaf_inputs.requires_grad_(True)

    model_out = model(
        inputs_embeds=leaf_inputs,
        output_hidden_states=True
    )
    layer_hidden = model_out.hidden_states[layer_idx][0, :, :]

    # Accumulate the loss in float32 regardless of the model's dtype.
    loss = F.mse_loss(
        layer_hidden.to(torch.float32),
        h_target.to(torch.float32),
        reduction='sum',
    )
    loss.backward()

    input_grad = leaf_inputs.grad.squeeze(0)  # type: ignore
    return input_grad.to(embeddings.dtype), loss


def compute_last_token_embedding_all_grad_emb(
    embeddings: torch.Tensor,
    model: torch.nn.Module,
    layer_idx: int,
    h_target: torch.Tensor,
) -> Tuple[torch.Tensor, float]:
    """
    Gradient of a mean-pooled hidden-state loss w.r.t. every embedding.

    A leading dimension is added to `embeddings` before the forward pass.
    Both `hidden_states[layer_idx][0, :, :]` and `h_target` are averaged
    over dimension 0, and the loss is the summed squared error between the
    two averages.

    NOTE: `loss.backward()` is used, so gradients are also accumulated into
    any model parameter that requires grad.

    Args:
        embeddings:  Input embeddings without a batch dimension; presumably
                     of shape (seq_len, hidden_size) - confirm.
        model:       A model accepting `inputs_embeds` and
                     `output_hidden_states`; must have at least one parameter
                     (its device is used for the inputs).
        layer_idx:   Index into `outputs.hidden_states`.
        h_target:    Target hidden states; averaged over dimension 0 before
                     comparison.

    Returns:
        grad_embeddings: Gradient w.r.t. the inputs, with the added leading
                         dimension squeezed away.
        loss:            The loss as returned by `mse_loss` (a 0-dim tensor;
                         `.item()` is not applied).
    """
    model_device = next(model.parameters()).device
    embeddings = embeddings.to(model_device)
    h_target = h_target.to(model_device)

    # Fresh leaf tensor with a leading batch dimension.
    leaf_inputs = embeddings.clone().detach().unsqueeze(0)
    leaf_inputs.requires_grad_(True)

    model_out = model(
        inputs_embeds=leaf_inputs,
        output_hidden_states=True
    )
    layer_hidden = model_out.hidden_states[layer_idx][0, :, :]

    # Compare the position-averaged states rather than individual positions.
    pooled_hidden = layer_hidden.mean(dim=0)
    pooled_target = h_target.mean(dim=0)
    loss = F.mse_loss(pooled_hidden, pooled_target, reduction='sum')
    loss.backward()

    return leaf_inputs.grad.squeeze(0), loss

