import typing

import torch
from transformers import AutoTokenizer, AutoModel

# Codebook-level tokenization of a batch of strings: one inner list of integer
# codebook (cluster) ids per input string.
TokenizedCodebook: typing.TypeAlias = list[list[int]]

# Text-level tokenization of a batch of strings: one inner list of token
# strings, as produced by a tokenizer, per input string.
TokenizedText: typing.TypeAlias = list[list[str]]


def discretize_text_batch(
    tokenizer: AutoTokenizer,
    model: AutoModel,
    clustering_model: "ClusteringDiscretizer",
    sentences_batch: list[str],
    device: str = "cuda",
    max_length: int = 200,
) -> tuple[TokenizedCodebook, TokenizedText]:
    """
    Discretize a batch of text by mapping tokens to their corresponding cluster IDs.

    The sentences are tokenized (padded and truncated to `max_length`), run
    through `model`, and the last hidden state of every non-padded token is
    passed to `clustering_model.discrete` to obtain one cluster ID per token.

    Args:
    - tokenizer (AutoTokenizer): A HuggingFace tokenizer instance. Its output must
      support `.tokens(batch_index)`.
    - model (AutoModel): A HuggingFace model instance whose output exposes
      `last_hidden_state`. Only the tokenized inputs are moved to `device` here;
      the model itself is not moved.
    - clustering_model: An instance of the clustering model; its `discrete` method
      is called with a 2D numpy array of token embeddings and must return a
      sliceable sequence of cluster IDs supporting `.tolist()`.
    - sentences_batch (list[str]): A batch of sentences to discretize.
    - device (str): The device to use for computations ('cuda' or 'cpu').
    - max_length (int): The maximum length of input sequences.

    Returns:
    - tuple[list[list[int]], list[list[str]]]: A pair
      `(cluster_ids_per_sentence, tokens_per_sentence)`. Both lists have one entry
      per input sentence; entry `i` of the first holds the cluster IDs and entry
      `i` of the second holds the matching token strings, in the same order and
      with padding positions excluded. An empty batch yields `([], [])`.
    """
    # Nothing to tokenize or embed for an empty batch.
    if not sentences_batch:
        return [], []

    # Tokenize the input text batch
    inputs = tokenizer(
        sentences_batch,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=max_length,
    ).to(device)
    with torch.no_grad():
        outputs = model(**inputs)
        token_embeddings = outputs.last_hidden_state
        attention_mask = inputs["attention_mask"]

    # Accumulate non-padded token embeddings (flattened sentence by sentence)
    selected_embeddings = accumulate_token_embeddings_masked_select(
        token_embeddings, attention_mask
    )
    embeddings_np = selected_embeddings.detach().cpu().numpy()
    # Discretize the embeddings using the clustering model (get cluster IDs)
    cluster_ids = clustering_model.discrete(embeddings_np)

    # One mask row per sentence, as plain Python ints, so it can be zipped with
    # the token strings below.
    mask_rows = attention_mask.tolist()

    # Convert the flattened list of cluster IDs into a list of lists
    cluster_ids_per_sentence: TokenizedCodebook = []
    tokens_per_sentence: TokenizedText = []
    start_idx = 0
    for index, mask_row in enumerate(mask_rows):
        # Get the cluster IDs for the current sentence
        length = sum(1 for keep in mask_row if keep)
        end_idx = start_idx + length
        cluster_ids_per_sentence.append(cluster_ids[start_idx:end_idx].tolist())
        start_idx = end_idx
        # Get the tokens for the current sentence. Filter by the attention mask
        # rather than slicing `[:length]`: the embeddings were selected by mask
        # position, so this keeps tokens aligned with their cluster IDs even
        # when the tokenizer pads on the left.
        tokens_per_sentence.append(
            [
                token
                for token, keep in zip(inputs.tokens(index), mask_row)
                if keep
            ]
        )

    return cluster_ids_per_sentence, tokens_per_sentence


def accumulate_token_embeddings_masked_select(
    token_embeddings: torch.Tensor, attention_mask: torch.Tensor
) -> torch.Tensor:
    """
    Gather the embeddings of all non-padded tokens into one 2D tensor via masked_select.

    Args:
    - token_embeddings (torch.Tensor): Token embeddings; the last dimension is the embedding dimension.
    - attention_mask (torch.Tensor): Attention mask with one entry per token position; non-zero marks a token to keep.

    Returns:
    - torch.Tensor: Tensor of shape (num_kept_tokens, embedding_dim) holding the kept embeddings in row-major order.
    """
    embedding_dim = token_embeddings.size(-1)
    # Broadcast the per-token mask across the embedding dimension
    keep = attention_mask.bool().unsqueeze(-1).expand_as(token_embeddings)
    # masked_select returns a flat 1D tensor of the kept values
    flat_values = token_embeddings.masked_select(keep)
    # Fold the flat values back into one row per kept token
    return flat_values.view(-1, embedding_dim)
