"""
Core text generation functions for language models.

This module provides the main generation functions including ragged batch generation
and loss computation functions.
"""

import logging
from typing import Literal

import torch
import torch.nn.functional as F
from transformers import (DynamicCache, HybridCache, PreTrainedModel,
                          PreTrainedTokenizerBase)

from ..types import JsonSchema
from .batching import with_max_batchsize
from .json_utils import JSONFilter, NullFilter, validate_json_strings
from .sampling import top_k_filtering, top_p_filtering
from .utils import get_stop_token_ids


@torch.no_grad()
def generate_ragged_batched(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    token_list: list[torch.LongTensor] | None = None,
    embedding_list: list[torch.FloatTensor] | None = None,
    initial_batch_size: int | None = None,
    use_cache: bool = True,
    verbose: bool = False,
    num_return_sequences: int = 1,
    **kwargs,
) -> list[list[str]] | list[list[torch.Tensor]]:
    """
    Generate completions for a list of prompts given either as tokens or as embeddings.
    Dynamically adjust batch size if an OOM error occurs (via `with_max_batchsize`).

    Prompts are sorted by length (shortest first) before batching, and the outputs
    are put back into the original prompt order before returning.

    Args:
        model: The language model to use for generation.
        tokenizer: The tokenizer for the model.
        token_list: List of unpadded 1D token tensors, one per prompt.
            Exactly one of token_list / embedding_list must be provided.
        embedding_list: List of unpadded 2D embedding tensors, one per prompt.
            Exactly one of token_list / embedding_list must be provided.
        initial_batch_size: Starting batch size for generation (passed to `with_max_batchsize`).
        use_cache: Whether `generate_ragged` should use a KV-cache.
        verbose: Passed to `with_max_batchsize`.
        num_return_sequences: Number of completions to generate per prompt.
        **kwargs: Additional arguments will be passed to `generate_ragged`
            (e.g. max_new_tokens, return_tokens, temperature, ...).
    Returns:
        List of outputs strings generated by the model. list[list[str]] with 'shape' (B, num_return_sequences)
        or if return_tokens is True,
        List of tokens generated by the model. Shape: (B, num_return_sequences, <= max_new_tokens)
    Raises:
        ValueError: If both or neither of token_list / embedding_list are provided.
    """
    if token_list is not None and embedding_list is not None:
        raise ValueError("Only one of token_list or embedding_list should be provided.")
    if token_list is None and embedding_list is None:
        raise ValueError("One of token_list or embedding_list must be provided.")
    input_type = "tokens" if token_list is not None else "embeddings"
    input_list = token_list if token_list is not None else embedding_list

    # Nothing to generate: return one (empty) list per prompt without calling the model.
    if len(input_list) == 0 or num_return_sequences == 0:
        return [[] for _ in range(len(input_list))]

    # Shorter sequences will come first to maximize batch size
    sorted_indexed_inputs = sorted(enumerate(input_list), key=lambda x: x[1].size(0))

    # Duplicating each prompt for multiple return sequences here is faster because it
    # avoids the slower per-sequence for-loop inside generate_ragged. We can't easily
    # move this inside generate_ragged directly because we need to do it outside of the
    # with_max_batchsize context.
    sorted_input_list = [it for _, item in sorted_indexed_inputs for it in [item] * num_return_sequences]
    original_indices = [index for index, _ in sorted_indexed_inputs]

    def func(chunk):
        """Wrapper function to handle a single chunk of inputs."""
        return generate_ragged(
            model=model,
            tokenizer=tokenizer,
            token_list=chunk if input_type == "tokens" else None,
            embedding_list=chunk if input_type == "embeddings" else None,
            use_cache=use_cache,
            num_return_sequences=1,
            **kwargs,
        )
    sorted_outputs = with_max_batchsize(func, sorted_input_list, initial_batch_size=initial_batch_size, verbose=verbose)

    # Unsort the outputs to match the original input order. Each duplicated prompt was
    # generated with num_return_sequences=1, so every entry holds exactly one sequence.
    outputs = [None] * len(input_list)
    for i, original_index in enumerate(original_indices):
        outputs[original_index] = [sorted_outputs[i*num_return_sequences + j][0] for j in range(num_return_sequences)]
    return outputs


@torch.no_grad()
def generate_ragged(
    model: PreTrainedModel,
    tokenizer: PreTrainedTokenizerBase,
    embedding_list: list[torch.FloatTensor] | None = None,
    token_list: list[torch.LongTensor] | None = None,
    max_new_tokens: int = 256,
    return_tokens: bool = False,
    padding_side: Literal["left", "right"] = "right",
    use_cache: bool = True,
    temperature: float = 0.0,
    top_p: float = 1.0,
    top_k: int = 0,
    num_return_sequences: int = 1,
    json_schema: JsonSchema = None
) -> list[list[str]] | list[list[torch.Tensor]]:
    """
    Generate completions for multiple prompts in a single batch.
    No KV-cache for left-padding yet.
    Heavily tested across models to be close to individual generations.
    This is far from trivial due to various padding (left/right) and masking issues.
    The final function is still not identical to individual generations, but it is close.
    The reason for this is probably that attention masks typically don't use -inf for
    masked tokens, but instead have values like -65504 for float16.
    This can lead to small differences in the final logits and thus the generated tokens.
    We are much closer to individual generations than HF model.generate, which often
    fails in mysterious ways for LLama & Qwen models. Generations for CircuitBreakers
    often look weird, but are actually just what the model would output with `generate`
    as well.

    Number of generations that are the same as single-batch:
        Model Name                         This function    HF generate
        cais/zephyr_7b_r2d2                      100/100        100/100
        ContinuousAT/Llama-2-7B-CAT              100/100         39/100
        ContinuousAT/Phi-CAT                      90/100         95/100
        ContinuousAT/Zephyr-CAT                  100/100        100/100
        google/gemma-2-2b-it                      55/100         56/100
        meta-llama/Meta-Llama-3.1-8B-Instruct     62/100         11/100
        meta-llama/Llama-2-7b-chat-hf             88/100         30/100
        microsoft/Phi-3-mini-4k-instruct          53/100         50/100
        mistralai/Mistral-7B-Instruct-v0.3        83/100         79/100
        qwen/Qwen2-7B-Instruct                    78/100         19/100
        ---------------------------------------------------------------
        Total                                   809/1000       579/1000

    Args:
        model: A pretrained model.
        tokenizer: A pretrained tokenizer. Its pad_token_id is used to fill unused token slots.
        embedding_list: list[torch.Tensor], optional
            A list of 2D embeddings for each prompt. Should not be padded and can be of different lengths.
        token_list: list[torch.Tensor], optional
            A list of 1D tokens for each prompt. Should not be padded and can be of different lengths.
        max_new_tokens: The maximum number of tokens to generate for each prompt. Must be positive;
            generation stops earlier only once every sequence has produced a stop token.
        return_tokens: If True, return token tensors instead of decoded strings.
        padding_side: "right" (supports the KV-cache) or "left" (requires use_cache=False).
        use_cache: Whether to use a KV-cache (right padding only).
        temperature: 0.0 means greedy decoding; > 0 enables multinomial sampling.
        top_p: Nucleus filtering threshold, only applied when temperature > 0 and top_p < 1.
        top_k: Top-k filtering, only applied when temperature > 0 and top_k > 0.
        num_return_sequences: Number of independent completions per prompt.
        json_schema: If given, logits are constrained with a JSONFilter and the decoded
            strings are validated with `validate_json_strings` (string output only).
    Returns:
        A list (one entry per prompt) of num_return_sequences completions, each truncated
        before its first stop token.
        If return_tokens is True, the completions are 1D CPU token tensors instead of strings.
    Raises:
        ValueError: If not exactly one of embedding_list/token_list is given, or padding_side is unknown.
        NotImplementedError: For padding_side="left" combined with use_cache=True.
    """
    if (embedding_list is None) == (token_list is None):
        raise ValueError("One of embedding_list or token_list must be provided.")
    if embedding_list is not None:
        assert all(e.ndim == 2 for e in embedding_list), "Embeddings must be 2D."
        embedding_list = [e.to(model.device) for e in embedding_list]
    if token_list is not None:
        assert all(t.ndim == 1 for t in token_list), "Tokens must be 1D."
        token_list = [t.to(model.device) for t in token_list]
        lengths = [t.size(0) for t in token_list]
        # Concatenate all tokens into a single tensor to get embeddings at once.
        concatenated_tokens = torch.cat(token_list)
        concatenated_embeddings = model.get_input_embeddings()(concatenated_tokens)
        embedding_list = list(torch.split(concatenated_embeddings, lengths))

    assert embedding_list is not None

    is_gemma = "gemma-2" in model.name_or_path or "gemma-3" in model.config.name_or_path
    B = len(embedding_list)

    if num_return_sequences == 0:
        # Same structure as the regular return value (list per prompt) for both output modes.
        return [[] for _ in range(B)]

    def sample_next_token(logits: torch.Tensor) -> torch.Tensor:
        """
        Sample a next token from the logits.
        Args:
            logits: (B, V)
        Returns:
            next_tokens: (B,)
        """
        if temperature > 0.0:
            logits = logits / temperature
            if top_p < 1.0:
                logits = top_p_filtering(logits, top_p)
            if top_k > 0:
                logits = top_k_filtering(logits, top_k)
            next_tokens = torch.multinomial(logits.softmax(dim=-1), num_samples=1)[:, 0]
        else:
            next_tokens = logits.argmax(dim=-1)
        return next_tokens

    stop_ids = get_stop_token_ids(model, tokenizer).to(model.device)
    embedding_layer = model.get_input_embeddings()
    idx_range = torch.arange(B, device=model.device)
    all_tokens = []
    for _ in range(num_return_sequences):
        # Each return sequence is an independent generation from the same prompts, so the
        # filter (constructed per batch and advanced through `step(prev_tokens, ...)`) and the
        # previous-token sentinel are re-created here; sharing them across sequences would
        # carry the end state of the previous sequence into the next one.
        token_filter = JSONFilter(json_schema, tokenizer, B) if json_schema else NullFilter()
        prev_tokens = torch.full((B,), tokenizer.pad_token_id, device=model.device)
        tokens = torch.full((B, max_new_tokens), tokenizer.pad_token_id, device=model.device)
        finished = torch.zeros(B, dtype=torch.bool, device=model.device)
        if padding_side == "left":
            if use_cache:
                raise NotImplementedError("KV-cache not implemented for left padding.")
            # Add left padding
            embeddings = torch.nn.utils.rnn.pad_sequence(
                [e.flip(0) for e in embedding_list], batch_first=True, padding_value=0
            ).flip(1)
            padded_embeddings = F.pad(embeddings, (0, 0, 0, max_new_tokens))
            # Create attention mask and position ids. Every row must span the full
            # padded length embeddings.size(1) + max_new_tokens: left padding followed
            # by the prompt and all slots that may be filled with generated tokens.
            lengths = [
                {
                    "padding": embeddings.size(1) - e.size(0),
                    "generation": e.size(0) + max_new_tokens,
                }
                for e in embedding_list
            ]
            attention_mask = torch.stack(
                [
                    torch.cat([torch.zeros(pl["padding"]), torch.ones([pl["generation"]])])
                    for pl in lengths
                ]
            ).to(model.device)
            position_ids = torch.stack(
                [
                    torch.cat([torch.zeros(pl["padding"]), torch.arange(pl["generation"])])
                    for pl in lengths
                ]
            ).long().to(model.device)
            next_token_idx = embeddings.size(1)
            for i in range(max_new_tokens):
                outputs = model(
                    inputs_embeds=padded_embeddings[:, :next_token_idx],
                    attention_mask=attention_mask[:, :next_token_idx],
                    position_ids=position_ids[:, :next_token_idx],
                    logits_to_keep=1,
                )
                logits = outputs.logits[:, -1].clone()
                logits = token_filter.step(prev_tokens, logits)
                next_tokens = sample_next_token(logits)
                padded_embeddings[idx_range, next_token_idx] = (
                    embedding_layer(next_tokens).detach()
                )
                tokens[:, i] = next_tokens
                finished |= torch.isin(next_tokens, stop_ids)
                if finished.all():
                    logging.info(f"Early exit after {i}/{max_new_tokens} tokens.")
                    break
                prev_tokens.fill_(tokenizer.pad_token_id)     # reset sentinel
                prev_tokens[~finished] = next_tokens[~finished]
                next_token_idx += 1
        elif padding_side == "right":
            # Add right padding
            embeddings = torch.nn.utils.rnn.pad_sequence(
                [e for e in embedding_list], batch_first=True, padding_value=0
            )
            padded_embeddings = F.pad(embeddings, (0, 0, 0, max_new_tokens))
            next_token_idx = torch.tensor([e.size(0) for e in embedding_list], device=model.device)

            if use_cache:
                # This is the hot path so we have additional optimizations here.
                # As we generate tokens, we keep track of which prompts are completed,
                # and only generate tokens for the active prompts.
                # This is slightly (~20%) slower if all prompts have similar length,
                # but faster if prompts have different lengths and **saves VRAM**.

                # Fill prefix cache
                if not is_gemma:
                    past_key_values = DynamicCache()
                else:
                    config = model.config if hasattr(model.config, "cache_implementation") else model.config.text_config
                    past_key_values = HybridCache(
                        config=config,
                        max_batch_size=B,
                        max_cache_len=next_token_idx.max().item() + max_new_tokens,
                        device=model.device,
                        dtype=model.dtype,
                    )
                    # we need to iterate like this because key_cache and value_cache
                    # dont have a setter method
                    for layer in past_key_values.layers:
                        layer.keys = layer.keys.to(model.device)
                        layer.values = layer.values.to(model.device)

                if next_token_idx.min() > 1:
                    model(
                        inputs_embeds=padded_embeddings[:, : next_token_idx.min() - 1],
                        past_key_values=past_key_values,
                        use_cache=True,
                        logits_to_keep=1,  # we dont really need any logits, but 0 would return all
                    )
                # Caching with right padding is a bit tricky:
                # Consider this set of prompts:
                #
                #  0: x x x x x o o o o o p p p p p
                #  1: x x x x x x o o o o o p p p p
                #  2: x x x x x x x x x x o o o o o
                # Each batch element can be in one of three states:
                #   - A: cache_filling,
                #   - B: generating,
                #   - C: finished.
                # We forward A | B = ~C through the model, but only use outputs of B to
                # forward-fill the input embeddings.
                # This means that caching works best when sequences have similar length.
                lengths = torch.zeros(B, device=model.device, dtype=torch.long)
                for i in range(max_new_tokens + next_token_idx.max() - next_token_idx.min()):
                    cur_idx = next_token_idx.min()
                    generating = (next_token_idx == cur_idx) & ~finished

                    active_embeddings = padded_embeddings[~finished, cur_idx - 1 : cur_idx].clone()
                    assert active_embeddings.size(1) == 1
                    cache_position = torch.arange(cur_idx - 1, cur_idx, device=model.device)
                    logits = model(
                        inputs_embeds=active_embeddings,
                        cache_position=cache_position,
                        past_key_values=past_key_values,
                        use_cache=True,
                        logits_to_keep=1,
                    ).logits[:, 0].clone()  # (B, vocab_size)

                    logits_out = torch.empty((B, logits.size(1)), dtype=model.dtype, device=model.device)
                    logits_out[~finished] = logits
                    logits_out = token_filter.step(prev_tokens, logits_out)
                    next_tokens = torch.full((B,), tokenizer.eos_token_id, device=model.device)
                    next_tokens[generating] = sample_next_token(logits_out[generating])
                    prev_tokens.fill_(tokenizer.pad_token_id)
                    prev_tokens[generating] = next_tokens[generating]

                    padded_embeddings[idx_range[generating], next_token_idx.min()] = embedding_layer(next_tokens[generating])
                    tokens[generating, lengths[generating]] = next_tokens[generating]
                    # have to manually crop the past_key_values to the correct length
                    # since we only add a single step at a time
                    finished_at_this_step = torch.zeros_like(finished)
                    finished_at_this_step[generating] = torch.isin(next_tokens[generating], stop_ids) | (lengths[generating] + 1 == max_new_tokens)

                    if finished_at_this_step.any():
                        still_active = (~finished & ~finished_at_this_step)[~finished]

                        finished |= finished_at_this_step
                        if finished.all():
                            if i < max_new_tokens - 1:
                                logging.info(f"Early exit after {i}/{max_new_tokens} tokens.")
                            break

                        for layer in past_key_values.layers:
                            layer.keys = layer.keys[still_active].clone()
                            layer.values = layer.values[still_active].clone()

                    next_token_idx[next_token_idx == cur_idx] += 1
                    lengths[generating] += 1
            else:
                for i in range(max_new_tokens):
                    # idx_range lives on model.device, like next_token_idx, so both
                    # advanced-index tensors share the logits' device.
                    logits = model(
                        inputs_embeds=padded_embeddings[:, : next_token_idx.max()]
                    ).logits[idx_range, next_token_idx - 1].clone()
                    logits = token_filter.step(prev_tokens, logits)
                    next_tokens = sample_next_token(logits)
                    padded_embeddings[idx_range, next_token_idx] = embedding_layer(next_tokens)
                    tokens[:, i] = next_tokens
                    finished |= torch.isin(next_tokens, stop_ids)
                    if finished.all():
                        logging.info(f"Early exit after {i}/{max_new_tokens} tokens.")
                        break
                    prev_tokens.fill_(tokenizer.pad_token_id)     # reset sentinel
                    prev_tokens[~finished] = next_tokens[~finished]
                    next_token_idx += 1
        else:
            raise ValueError(f"Unknown padding_side: {padding_side}")
        all_tokens.append(tokens)

    all_tokens = torch.stack(all_tokens, dim=1).cpu()  # (B, N, T)
    if return_tokens:
        stop_ids_cpu = stop_ids.cpu()
        B, N, T = all_tokens.size()
        tokens = []
        for i in range(B):
            current_tokens = []
            for j in range(N):
                # Truncate each sequence right before its first stop token (if any).
                stop_idx = torch.where(torch.isin(all_tokens[i, j], stop_ids_cpu))[0]
                stop_idx = stop_idx[0].item() if stop_idx.numel() > 0 else T
                current_tokens.append(all_tokens[i, j, :stop_idx])
            tokens.append(current_tokens)
        return tokens

    # Drop ids the tokenizer cannot map (None) and empty strings: str.split(None) would
    # split on whitespace and str.split("") raises.
    stop_tokens = [t for t in tokenizer.convert_ids_to_tokens(stop_ids.tolist()) if t]
    completion = [tokenizer.batch_decode(all_tokens[i], skip_special_tokens=False) for i in range(B)]
    # Cut each completion at the earliest occurring stop-token string.
    completion = [
        [min((c.split(t)[0] for t in stop_tokens), key=len, default=c) for c in completion[i]]
        for i in range(B)
    ]
    if json_schema:
        validate_json_strings([gen for comp in completion for gen in comp], json_schema)

    return completion


def get_losses_batched(
    model: PreTrainedModel,
    targets: list[torch.Tensor],
    embedding_list: list[torch.Tensor] | None = None,
    token_list: list[torch.Tensor] | None = None,
    padding_side: Literal["left", "right"] = "right",
    initial_batch_size: int | None = None,
    verbose: bool = False,
) -> list[torch.Tensor]:
    """
    Get per-timestep losses for multiple ragged prompts in a single batch.
    No KV-cache for now.

    Prompts are sorted by length (shortest first), processed in chunks through
    `with_max_batchsize`, and the results are returned in the original prompt order.
    Unlike the generation functions, this function is not wrapped in torch.no_grad.

    Args:
        model: A pretrained model.
        targets: A list of 1D tensors containing the target tokens for each prompt. Should already be shifted by one position w.r.t to the embeddings/tokens.
        embedding_list: list[torch.Tensor], optional
            A list of 2D tensors of embeddings for each prompt. Should not be padded and can be of different lengths.
        token_list: list[torch.Tensor], optional
            A list of tokens for each prompt. Should not be padded and can be of different lengths.
        padding_side: The side to pad the embeddings on ("left" or "right").
        initial_batch_size: The initial batch size to use for the batched loss computation.
            Defaults to the number of prompts.
        verbose: Whether to print verbose output (passed to `with_max_batchsize`).
    Returns:
        A list with one 1D tensor of unreduced cross-entropy losses per prompt. The loss
        at the last target position is excluded, so entry i has targets[i].size(0) - 1
        elements.
    Raises:
        ValueError: If not exactly one of embedding_list/token_list is given, if the number
            of targets differs from the number of prompts, or if padding_side is unknown.
    """
    if (embedding_list is None) == (token_list is None):
        raise ValueError("Either embedding_list or token_list must be provided.")
    # Validate up front so the error is not raised from inside the batching helper.
    if padding_side not in ("left", "right"):
        raise ValueError(f"Unknown padding_side: {padding_side}")
    if embedding_list is not None:
        assert all(e.ndim == 2 for e in embedding_list), "Embeddings must be 2D."
        embedding_list = [e.to(model.device) for e in embedding_list]
    if token_list is not None:
        assert all(t.ndim == 1 for t in token_list), "Tokens must be 1D."
        token_list = [t.to(model.device) for t in token_list]
        embedding_list = [
            model.get_input_embeddings()(t.unsqueeze(0))[0] for t in token_list
        ]
    assert embedding_list is not None
    assert all(t.ndim == 1 for t in targets), "Targets must be 1D."
    if len(targets) != len(embedding_list):
        raise ValueError(
            f"Expected one target per prompt, got {len(targets)} targets for {len(embedding_list)} prompts."
        )
    if padding_side == "left":
        # Emitted once per call (not once per chunk / OOM retry).
        logging.warning("Padding side 'left' is not recommended for get_batched_losses as it may yield nans.")

    def get_losses_func(embedding_list, targets):
        """Compute per-position losses for one chunk of prompts and their targets."""
        B = len(embedding_list)
        if padding_side == "left":
            # Add left padding
            embeddings = torch.nn.utils.rnn.pad_sequence(
                [e.flip(0) for e in embedding_list], batch_first=True, padding_value=0
            ).flip(1)

            targets_padded = torch.nn.utils.rnn.pad_sequence(
                [t.flip(0) for t in targets], batch_first=True, padding_value=0
            ).flip(1)
            # Create attention mask and position ids
            lengths = [
                {
                    "padding": embeddings.size(1) - e.size(0),
                    "generation": e.size(0),
                }
                for e in embedding_list
            ]
            attention_mask = torch.stack(
                [
                    torch.cat([torch.zeros(pl["padding"]), torch.ones([pl["generation"]])])
                    for pl in lengths
                ]
            ).to(model.device)
            position_ids = torch.stack(
                [
                    torch.cat(
                        [torch.zeros(pl["padding"]), torch.arange(pl["generation"])]
                    )
                    for pl in lengths
                ]
            ).long().to(model.device)

            outputs = model(
                inputs_embeds=embeddings,
                attention_mask=attention_mask,
                position_ids=position_ids,
            ).logits
            losses = F.cross_entropy(
                outputs.reshape(-1, outputs.size(-1)),
                targets_padded.view(-1).to(outputs.device),
                reduction="none",
            )
            losses = losses.view(B, -1)
            # Targets are right-aligned; keep all but the last target position.
            losses = [losses[i, -t.size(0) : -1] for i, t in enumerate(targets)]
        elif padding_side == "right":
            # Add right padding
            embeddings = torch.nn.utils.rnn.pad_sequence(
                [e for e in embedding_list], batch_first=True, padding_value=0
            )
            targets_padded = torch.nn.utils.rnn.pad_sequence(
                [t for t in targets], batch_first=True, padding_value=0
            ).to(model.device)
            outputs = model(inputs_embeds=embeddings).logits
            losses = F.cross_entropy(
                outputs.reshape(-1, outputs.size(-1)),
                targets_padded.view(-1),
                reduction="none",
            )
            losses = losses.view(B, -1)
            # Targets are left-aligned; keep all but the last target position.
            losses = [losses[i, : t.size(0) - 1] for i, t in enumerate(targets)]
        else:
            raise ValueError(f"Unknown padding_side: {padding_side}")
        return losses

    if initial_batch_size is None:
        initial_batch_size = len(embedding_list)

    # Shorter sequences will come first to maximize batch size
    sorted_indexed_inputs = sorted(enumerate(embedding_list), key=lambda x: x[1].size(0))
    sorted_input_list = [item for _, item in sorted_indexed_inputs]
    original_indices = [index for index, _ in sorted_indexed_inputs]
    sorted_targets = [targets[i] for i in original_indices]
    losses = with_max_batchsize(get_losses_func, sorted_input_list, sorted_targets, initial_batch_size=initial_batch_size, verbose=verbose)
    # Unsort the outputs to match the original input order
    outputs = [None] * len(embedding_list)
    for i, original_index in enumerate(original_indices):
        outputs[original_index] = losses[i]  # (len(targets[original_index]) - 1,)
    return outputs
