# SPDX-License-Identifier: MIT
from __future__ import annotations
import math
import re
from typing import Iterable, List, Optional, Tuple, Union
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from .embeddings import generate_with_embedding_tensor_hf, generate_with_new_embedding_hf


def _get_sentence_embeddings_via_generate(input_ids: torch.LongTensor, model, tokenizer) -> torch.Tensor:
    """Return the embedding tensor built by ``generate_with_embedding_tensor_hf``.

    ``input_ids`` is converted with ``.tolist()`` before being handed to the
    helper; only the first element of the helper's 2-tuple result is kept.
    A 1-D result gains a leading axis, and the tensor is moved to the device
    of the model's first parameter.

    Args:
        input_ids: token ids; converted to a Python list via ``.tolist()``.
        model: forwarded to the helper; its first parameter decides the
            output device.
        tokenizer: forwarded to the helper.

    Returns:
        The embedding tensor, at least 2-D, on the model's device.
    """
    embeddings, _unused = generate_with_embedding_tensor_hf(
        input_ids.tolist(),
        model=model,
        tokenizer=tokenizer,
    )
    # Promote a single vector to a one-row matrix so callers can index rows.
    embeddings = embeddings.unsqueeze(0) if embeddings.dim() == 1 else embeddings
    target_device = next(model.parameters()).device
    return embeddings.to(target_device)


def _align_tokens_to_words(
    sentence: str,
    offsets: list[list[int]],
    special_mask: list[int],
) -> tuple[list[tuple[int, int, str]], list[list[int]]]:
    """Split *sentence* into word/punctuation spans and group token positions by span.

    Spans are runs of ``\\w`` characters or single non-word, non-space
    characters. Each token that is not special (``special_mask == 1``) and
    whose offset is not ``(0, 0)`` is assigned to the span it overlaps by the
    most characters. A token that overlaps no span at all (e.g. a piece that
    covers only whitespace) is left unassigned rather than being attached to
    an arbitrary span.

    Args:
        sentence: the raw input text the offsets refer to.
        offsets: per-token ``[start, end)`` character offsets into *sentence*.
        special_mask: per-token flag, ``1`` for special tokens.

    Returns:
        ``(word_spans, word_to_tokens)`` where ``word_spans[i]`` is
        ``(start, end, text)`` and ``word_to_tokens[i]`` lists, in ascending
        order, the token positions assigned to span ``i``.
    """
    word_spans = [
        (m.start(), m.end(), m.group(0))
        for m in re.finditer(r"\w+|[^\w\s]", sentence, flags=re.UNICODE)
    ]
    word_to_tokens: list[list[int]] = [[] for _ in word_spans]
    for t_idx, (start, end) in enumerate(offsets):
        if special_mask[t_idx] == 1 or (start == 0 and end == 0):
            continue
        # Start at overlap 0 with strict ">" so only a positive overlap can win;
        # tokens touching no span stay unassigned (best_idx == -1).
        best_idx, best_overlap = -1, 0
        for w_idx, (w_start, w_end, _text) in enumerate(word_spans):
            overlap = min(end, w_end) - max(start, w_start)
            if overlap > best_overlap:
                best_idx, best_overlap = w_idx, overlap
        if best_idx >= 0:
            word_to_tokens[best_idx].append(t_idx)
    return word_spans, word_to_tokens


def _reverse_fc(net, y_adj: torch.Tensor) -> torch.Tensor:
    """Map *y_adj* back through ``net.fc`` using a ``reverse_fc_layer`` function.

    A ``reverse_fc_layer`` defined directly on ``type(net)`` is preferred and
    is called unbound as ``reverse_fc_layer(y_adj, net.fc)``. If the class does
    not define one, or it returns ``None``, ``.net.reverse_fc_layer`` is used.
    """
    own = type(net).__dict__.get("reverse_fc_layer")
    if isinstance(own, staticmethod):
        # staticmethod objects are only directly callable from Python 3.10 on.
        own = own.__func__
    x_hat_flat = own(y_adj, net.fc) if own is not None else None
    if x_hat_flat is None:
        from .net import reverse_fc_layer

        x_hat_flat = reverse_fc_layer(y_adj, net.fc)
    return x_hat_flat


def pipeline_on_target_model(
    sentence: str,
    delta: float,
    threshold: float,
    target_model: AutoModelForCausalLM,
    tokenizer: AutoTokenizer,
    net,
    net_device: str,
    max_tokens: int = 2,
    max_new_tokens: int = 50,
    capture_mode: str = "token",
    verbose: bool = False,
) -> str:
    """Edit per-word token embeddings of *sentence* with *net*, then generate from them.

    Steps:
      1. Tokenize *sentence* (with special tokens, character offsets and the
         special-tokens mask) and obtain one embedding row per token from
         ``generate_with_embedding_tensor_hf(..., max_new_tokens=0)``.
      2. Split the sentence into ``\\w+`` words and single punctuation
         characters, and assign each non-special token to the word it
         overlaps most.
      3. For every non-punctuation word, take the embeddings of its first
         ``max_tokens`` tokens (zero-padded to ``max_tokens`` rows), run
         ``net`` on them as a ``(1, max_tokens, d_model)`` float32 tensor on
         ``net_device``. If ``y[0, 0] > threshold`` (presumably a toxicity
         score — confirm against ``net``), subtract ``delta`` from ``y[0, 0]``,
         invert ``net.fc`` via ``reverse_fc_layer``, reshape to
         ``(max_tokens, d_model)``, tile it to the word's token count and
         overwrite those token embeddings.
      4. Generate from the edited embeddings with
         ``generate_with_new_embedding_hf`` and return its result.

    Args:
        sentence: input text.
        delta: amount subtracted from ``y[0, 0]`` for words above threshold.
        threshold: words whose ``y[0, 0]`` exceeds this are edited.
        target_model: model whose first parameter's device hosts the
            embeddings; its input-embedding width gives ``d_model``.
        tokenizer: must be a fast tokenizer (offset mappings are required).
        net: callable on the padded embedding window; must expose ``fc``.
        net_device: device the window is moved to before calling ``net``.
        max_tokens: number of token embeddings per word fed to ``net``; >= 1.
        max_new_tokens: forwarded to ``generate_with_new_embedding_hf``.
        capture_mode: accepted for API compatibility; currently unused here.
        verbose: forwarded to ``generate_with_new_embedding_hf``.

    Returns:
        Whatever ``generate_with_new_embedding_hf`` returns (annotated ``str``).

    Raises:
        ValueError: if the tokenizer is not fast or ``max_tokens < 1``.
        RuntimeError: if the embedding helper does not return one row per token.
    """
    if not getattr(tokenizer, "is_fast", False):
        # return_offsets_mapping is only supported by fast tokenizers.
        raise ValueError("pipeline_on_target_model requires a fast tokenizer (tokenizer.is_fast)")
    if max_tokens < 1:
        raise ValueError(f"max_tokens must be >= 1, got {max_tokens}")

    device = next(target_model.parameters()).device
    d_model = target_model.get_input_embeddings().weight.shape[1]
    add_special_tokens_for_sentence = True
    enc = tokenizer(
        sentence,
        add_special_tokens=add_special_tokens_for_sentence,
        return_offsets_mapping=True,
        return_special_tokens_mask=True,
        return_tensors="pt",
    )
    input_ids = enc.input_ids.to(device)
    offsets = enc.offset_mapping[0].tolist()
    special_mask = enc.special_tokens_mask[0].tolist()
    seq_len = input_ids.shape[1]

    sent_emb, _ = generate_with_embedding_tensor_hf(
        input_ids[0].tolist(), max_new_tokens=0, model=target_model, tokenizer=tokenizer
    )
    if sent_emb.dim() == 1:
        sent_emb = sent_emb.unsqueeze(0)
    sent_emb = sent_emb.to(device)
    if sent_emb.shape[0] != seq_len:
        raise RuntimeError(
            f"expected {seq_len} embedding rows (one per token), got {sent_emb.shape[0]}"
        )

    word_spans, word_to_tokens = _align_tokens_to_words(sentence, offsets, special_mask)

    with torch.no_grad():
        for (_start, _end, word_text), tok_indices in zip(word_spans, word_to_tokens):
            # Skip words with no assigned tokens and pure-punctuation spans.
            if not tok_indices or re.fullmatch(r"[^\w\s]+", word_text, flags=re.UNICODE):
                continue
            span_len = len(tok_indices)
            # Window of the word's first max_tokens embeddings, zero-padded.
            x_seq = sent_emb[tok_indices[:max_tokens], :]
            if x_seq.shape[0] < max_tokens:
                pad = torch.zeros(
                    max_tokens - x_seq.shape[0], d_model, device=x_seq.device, dtype=x_seq.dtype
                )
                x_seq = torch.cat([x_seq, pad], dim=0)
            y = net(x_seq.unsqueeze(0).to(device=net_device, dtype=torch.float32))
            tox = float(y[0, 0].item())
            if tox > threshold:
                y_adj = y.clone()
                y_adj[0, 0] = y_adj[0, 0] - delta
                # reshape (not view): the inverted output need not be contiguous.
                x_hat = _reverse_fc(net, y_adj).reshape(1, max_tokens, d_model)[0]
                # Tile the max_tokens reconstructed rows to cover every token of
                # the word; reps == 1 when span_len <= max_tokens.
                reps = math.ceil(span_len / max_tokens)
                repl = x_hat.repeat(reps, 1)[:span_len, :]
                sent_emb[tok_indices, :] = repl.to(device=sent_emb.device, dtype=sent_emb.dtype)

    new_text = generate_with_new_embedding_hf(
        concatenated_tensor=sent_emb.unsqueeze(0),
        input_sentence=sentence,
        model=target_model,
        tokenizer=tokenizer,
        max_new_tokens=max_new_tokens,
        k=0,
        verbose=verbose,
        eos_token_id=None,
        add_special_tokens=add_special_tokens_for_sentence,
    )
    return new_text